• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for tensorflow.python.client.session.Session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import random
22import os
23import sys
24import threading
25import time
26import warnings
27
28import numpy as np
29import six
30from six.moves import xrange  # pylint: disable=redefined-builtin
31
32from tensorflow.core.framework import attr_value_pb2
33from tensorflow.core.lib.core import error_codes_pb2
34from tensorflow.core.protobuf import config_pb2
35from tensorflow.python.client import session
36from tensorflow.python.eager import context
37from tensorflow.python.framework import common_shapes
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import device as framework_device_lib
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import function
43from tensorflow.python.framework import importer
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_util
47from tensorflow.python.framework import test_util
48from tensorflow.python.framework import versions
49from tensorflow.python.ops import array_ops
50from tensorflow.python.ops import control_flow_ops
51from tensorflow.python.ops import data_flow_ops
52from tensorflow.python.ops import gen_control_flow_ops
53# Import gradients to resolve circular imports
54from tensorflow.python.ops import gradients  # pylint: disable=unused-import
55from tensorflow.python.ops import gradients_impl
56from tensorflow.python.ops import math_ops
57# Import resource_variable_ops for the variables-to-tensor implicit conversion.
58from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
59from tensorflow.python.ops import state_ops
60from tensorflow.python.ops import variables
61from tensorflow.python.platform import googletest
62from tensorflow.python.training import server_lib
63from tensorflow.python.util import compat
64
65try:
66  import attr  # pylint:disable=g-import-not-at-top
67except ImportError:
68  attr = None
69
70
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
# This registers `common_shapes.unknown_shape` as the Python shape function
# for the 'ConstructionFails' op type. It is a module-level statement, so it
# runs once, when this test module is imported.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
74
75
76class SessionTest(test_util.TensorFlowTestCase):
77
78  def setUp(self):
79    super(SessionTest, self).setUp()
80    warnings.simplefilter('always')
81
82  def testUseExistingGraph(self):
83    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
84      a = constant_op.constant(6.0, shape=[1, 1])
85      b = constant_op.constant(7.0, shape=[1, 1])
86      c = math_ops.matmul(a, b, name='matmul')
87    with session.Session(graph=g):
88      result = c.eval()
89      self.assertAllEqual(result, [[42.0]])
90
91  def testUseDefaultGraph(self):
92    with ops.Graph().as_default(), ops.device('/cpu:0'):
93      a = constant_op.constant(6.0, shape=[1, 1])
94      b = constant_op.constant(7.0, shape=[1, 1])
95      c = math_ops.matmul(a, b, name='matmul')
96      with session.Session():
97        result = c.eval()
98        self.assertAllEqual(result, [[42.0]])
99
100  def testCreate(self):
101    with session.Session():
102      inp = constant_op.constant(10.0, shape=[2, 3], name='W1')
103      copy = array_ops.identity(inp)
104      # Test with feed.
105      # TODO(mrry): Investigate why order='F' didn't work.
106      arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
107      copy_val = copy.eval({'W1:0': arr})
108      self.assertAllEqual(arr, copy_val)
109      # Test without feed.
110      copy_val = copy.eval()
111      self.assertAllEqual(
112          np.asarray(
113              [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
114          copy_val)
115
116  def testManyCPUs(self):
117    with session.Session(
118        config=config_pb2.ConfigProto(device_count={
119            'CPU': 2, 'GPU': 0
120        })) as sess:
121      inp = constant_op.constant(10.0, name='W1')
122      self.assertAllEqual(inp.eval(), 10.0)
123
124      num_cpu_devices = 0
125      num_gpu_devices = 0
126      for device in sess.list_devices():
127        device_type = framework_device_lib.DeviceSpec.from_string(
128            device.name).device_type
129        if device_type == 'CPU':
130          num_cpu_devices += 1
131        elif device_type == 'GPU':
132          num_gpu_devices += 1
133      self.assertEqual(2, num_cpu_devices)
134      self.assertEqual(0, num_gpu_devices)
135
136  def testPerSessionThreads(self):
137    with session.Session(
138        config=config_pb2.ConfigProto(use_per_session_threads=True)):
139      inp = constant_op.constant(10.0, name='W1')
140      self.assertAllEqual(inp.eval(), 10.0)
141
142  def testSessionInterOpThreadPool(self):
143    config = config_pb2.ConfigProto()
144    pool = config.session_inter_op_thread_pool.add()
145    with session.Session(config=config) as s:
146      inp = constant_op.constant(10.0, name='W1')
147      results = s.run([inp])
148      self.assertAllEqual([10.0], results)
149
150    pool = config.session_inter_op_thread_pool.add()
151    pool.num_threads = 1
152    with session.Session(config=config) as s:
153      inp = constant_op.constant(20.0, name='W2')
154      results = s.run([inp])
155      self.assertAllEqual([20.0], results)
156
157    pool = config.session_inter_op_thread_pool.add()
158    pool.num_threads = 1
159    pool.global_name = 't1'
160    run_options = config_pb2.RunOptions()
161    run_options.inter_op_thread_pool = (
162        len(config.session_inter_op_thread_pool) - 1)
163    with session.Session(config=config) as s:
164      inp = constant_op.constant(30.0, name='W2')
165      results = s.run([inp], options=run_options)
166      self.assertAllEqual([30.0], results)
167
168  def testErrorsReported(self):
169    with session.Session() as s:
170      constant_op.constant(10.0, name='W1')
171      with self.assertRaises(ValueError):
172        s.run('foo:0')
173
174  def testErrorPayload(self):
175    with session.Session():
176      a = array_ops.placeholder(dtypes.float32)
177      with self.assertRaisesOpError(lambda e: e.op == a.op):
178        a.eval()
179
180  def testErrorCodeWithNoNodeDef(self):
181    with session.Session() as s:
182      a = array_ops.placeholder(dtypes.float32, shape=[])
183      b = array_ops.placeholder(dtypes.float32, shape=[])
184      r1 = math_ops.add(a, b)
185
186      def exc_predicate(e):
187        return (e.op is None and e.node_def is None and
188                e.error_code == error_codes_pb2.INVALID_ARGUMENT)
189
190      with self.assertRaisesOpError(exc_predicate):
191        # Run with a bogus handle.
192        s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
193
194  def testErrorBasedOn(self):
195    with session.Session() as sess:
196      a = constant_op.constant(0.0, shape=[2, 3])
197      # NOTE(mrry): The original_op is nonsense, but used here to test that the
198      #   errors are reported correctly.
199      with sess.graph._original_op(a.op):
200        b = array_ops.identity(a, name='id')
201      with sess.graph._original_op(b.op):
202        c = array_ops.placeholder(dtypes.float32)
203
204      def exc_predicate(e):
205        return (e.op == c.op and e.op._original_op == b.op and
206                e.op._original_op._original_op == a.op)
207
208      with self.assertRaisesOpError(exc_predicate):
209        c.eval()
210
211  def testFetchNone(self):
212    with session.Session() as s:
213      a = constant_op.constant(1.0)
214      with self.assertRaises(TypeError):
215        s.run(None)
216      with self.assertRaises(TypeError):
217        s.run([None])
218      with self.assertRaises(TypeError):
219        s.run({'b': None})
220      with self.assertRaises(TypeError):
221        s.run({'a': a, 'b': None})
222
223  def testFetchSingleton(self):
224    with session.Session() as sess:
225      a = constant_op.constant(42.0)
226      res = sess.run(a)
227      self.assertEqual(42.0, res)
228      res = sess.run(a.op)  # An op, not a tensor.
229      self.assertEqual(None, res)
230      tensor_runner = sess.make_callable(a)
231      res = tensor_runner()
232      self.assertEqual(42.0, res)
233      op_runner = sess.make_callable(a.op)
234      res = op_runner()
235      self.assertEqual(None, res)
236
237  def testFetchSingletonByName(self):
238    with session.Session() as sess:
239      a = constant_op.constant(42.0)
240      res = sess.run(a.name)
241      self.assertEqual(42.0, res)
242      res = sess.run(a.op)  # An op, not a tensor.
243      self.assertEqual(None, res)
244
245  def testFetchList(self):
246    with session.Session() as sess:
247      a = constant_op.constant(42.0)
248      b = control_flow_ops.no_op()  # An op, not a tensor.
249      c = constant_op.constant(44.0)
250      v = variables.Variable([54.0])
251      assign = v.assign([63.0])
252      res = sess.run([a, b, c, a.name, assign.op])
253      self.assertTrue(isinstance(res, list))
254      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
255      list_runner = sess.make_callable([a, b, c, a.name, assign.op])
256      res = list_runner()
257      self.assertTrue(isinstance(res, list))
258      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
259
260  def testFetchTuple(self):
261    with session.Session() as sess:
262      a = constant_op.constant(42.0)
263      b = control_flow_ops.no_op()  # An op, not a tensor.
264      c = constant_op.constant(44.0)
265      res = sess.run((a, b, c, a.name))
266      self.assertTrue(isinstance(res, tuple))
267      self.assertEqual((42.0, None, 44.0, 42.0), res)
268      tuple_runner = sess.make_callable((a, b, c, a.name))
269      res = tuple_runner()
270      self.assertTrue(isinstance(res, tuple))
271      self.assertEqual((42.0, None, 44.0, 42.0), res)
272
273  def testFetchNamedTuple(self):
274    # pylint: disable=invalid-name
275    ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
276    # pylint: enable=invalid-name
277    with session.Session() as sess:
278      a = constant_op.constant(42.0)
279      b = control_flow_ops.no_op()  # An op, not a tensor.
280      c = constant_op.constant(44.0)
281      res = sess.run(ABC(a, b, c))
282      self.assertTrue(isinstance(res, ABC))
283      self.assertEqual(42.0, res.a)
284      self.assertEqual(None, res.b)
285      self.assertEqual(44.0, res.c)
286      namedtuple_runner = sess.make_callable(ABC(a, b, c))
287      res = namedtuple_runner()
288      self.assertTrue(isinstance(res, ABC))
289      self.assertEqual(42.0, res.a)
290      self.assertEqual(None, res.b)
291      self.assertEqual(44.0, res.c)
292
293  def testFetchDict(self):
294    with session.Session() as sess:
295      a = constant_op.constant(42.0)
296      b = control_flow_ops.no_op()  # An op, not a tensor.
297      c = constant_op.constant(44.0)
298      res = sess.run({'a': a, 'b': b, 'c': c})
299      self.assertTrue(isinstance(res, dict))
300      self.assertEqual(42.0, res['a'])
301      self.assertEqual(None, res['b'])
302      self.assertEqual(44.0, res['c'])
303
304  def testFetchOrderedDict(self):
305    with session.Session() as sess:
306      a = constant_op.constant(42.0)
307      b = control_flow_ops.no_op()  # An op, not a tensor.
308      c = constant_op.constant(44.0)
309      res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
310      self.assertTrue(isinstance(res, collections.OrderedDict))
311      self.assertEqual([3, 2, 1], list(res.keys()))
312      self.assertEqual(42.0, res[3])
313      self.assertEqual(None, res[2])
314      self.assertEqual(44.0, res[1])
315
316  @test_util.run_v1_only('b/120545219')
317  def testFetchAttrs(self):
318    if attr is None:
319      self.skipTest('attr module is unavailable.')
320
321    @attr.s
322    class SampleAttr(object):
323      field1 = attr.ib()
324      field2 = attr.ib()
325
326    val1 = np.array([1.2, 3.4, 5.6])
327    val2 = np.array([[1, 2], [4, 3]])
328    val3 = np.array([10, 20, 30])
329
330    t1 = constant_op.constant(val1)
331    t2 = constant_op.constant(val2)
332
333    sample = SampleAttr(t1, t2)
334    with session.Session() as sess:
335      result = sess.run(sample)
336      self.assertIsInstance(result, SampleAttr)
337      self.assertAllEqual(val1, result.field1)
338      self.assertAllEqual(val2, result.field2)
339
340      result = sess.run(sample, feed_dict={sample.field1: val3})
341      self.assertIsInstance(result, SampleAttr)
342      self.assertAllEqual(val3, result.field1)
343      self.assertAllEqual(val2, result.field2)
344
345  @test_util.run_v1_only('b/120545219')
346  def testFetchNestedAttrs(self):
347    if attr is None:
348      self.skipTest('attr module is unavailable.')
349
350    @attr.s
351    class SampleAttr(object):
352      field0 = attr.ib()
353      field1 = attr.ib()
354
355    v1 = 10
356    v2 = 20
357    v3 = np.float32(1.2)
358    v4 = np.float32(3.4)
359    v5 = np.float64(100.001)
360    v6 = np.float64(-23.451)
361    arr1 = np.array([1.2, 6.7, 3.4])
362    arr2 = np.array([7, 11, 3])
363    sample = SampleAttr(
364        SampleAttr(
365            SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
366            SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
367        {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
368         'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
369
370    with session.Session() as sess:
371      result = sess.run(sample)
372      self.assertIsInstance(result, SampleAttr)
373      self.assertIsInstance(result.field0, SampleAttr)
374      self.assertIsInstance(result.field0.field0, SampleAttr)
375      self.assertIsInstance(result.field0.field1, SampleAttr)
376      self.assertIsInstance(result.field0.field1.field0, np.ndarray)
377      self.assertAllEqual(arr1, result.field0.field1.field0)
378      self.assertIsInstance(result.field0.field1.field1, np.ndarray)
379      self.assertAllEqual(arr2, result.field0.field1.field1)
380      self.assertIsInstance(result.field1, dict)
381      self.assertIn('A', result.field1)
382      self.assertIn('B', result.field1)
383      self.assertIsInstance(result.field1['A'], SampleAttr)
384      self.assertAllEqual(
385          [v3, v4],
386          [result.field1['A'].field0, result.field1['A'].field1])
387      self.assertIsInstance(result.field1['B'], list)
388      self.assertEqual(1, len(result.field1['B']))
389      self.assertIsInstance(result.field1['B'][0], SampleAttr)
390      self.assertAllEqual(
391          [v5, v6],
392          [result.field1['B'][0].field0, result.field1['B'][0].field1])
393
394  def testFetchNestingEmptyOneLevel(self):
395    with session.Session() as sess:
396      a_val = 11.0
397      a = constant_op.constant(a_val)
398
399      res = sess.run([[], tuple(), {}])
400      self.assertTrue(isinstance(res, list))
401      self.assertEquals(3, len(res))
402      self.assertTrue(isinstance(res[0], list))
403      self.assertEqual(0, len(res[0]))
404      self.assertTrue(isinstance(res[1], tuple))
405      self.assertEqual(0, len(res[1]))
406      self.assertTrue(isinstance(res[2], dict))
407      self.assertEqual(0, len(res[2]))
408
409      res = sess.run([[], tuple(), {}, a])
410      self.assertTrue(isinstance(res, list))
411      self.assertEquals(4, len(res))
412      self.assertTrue(isinstance(res[0], list))
413      self.assertEqual(0, len(res[0]))
414      self.assertTrue(isinstance(res[1], tuple))
415      self.assertEqual(0, len(res[1]))
416      self.assertTrue(isinstance(res[2], dict))
417      self.assertEqual(0, len(res[2]))
418      self.assertEqual(a_val, res[3])
419
420  def testFetchNestingOneLevel(self):
421    with session.Session() as sess:
422      # pylint: disable=invalid-name
423      ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
424      DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g'])
425      # pylint: enable=invalid-name
426      a_val = 42.0
427      b_val = None
428      c_val = 44.0
429      a = constant_op.constant(a_val)
430      b = control_flow_ops.no_op()  # An op, not a tensor.
431      c = constant_op.constant(c_val)
432      # List of lists, tuples, namedtuple, and dict
433      res = sess.run([[a, b, c], (a, b, c),
434                      ABC(a=a, b=b, c=c), {
435                          'a': a.name,
436                          'c': c,
437                          'b': b
438                      }])
439      self.assertTrue(isinstance(res, list))
440      self.assertEqual(4, len(res))
441      self.assertTrue(isinstance(res[0], list))
442      self.assertEqual(3, len(res[0]))
443      self.assertEqual(a_val, res[0][0])
444      self.assertEqual(b_val, res[0][1])
445      self.assertEqual(c_val, res[0][2])
446      self.assertTrue(isinstance(res[1], tuple))
447      self.assertEqual(3, len(res[1]))
448      self.assertEqual(a_val, res[1][0])
449      self.assertEqual(b_val, res[1][1])
450      self.assertEqual(c_val, res[1][2])
451      self.assertTrue(isinstance(res[2], ABC))
452      self.assertEqual(a_val, res[2].a)
453      self.assertEqual(b_val, res[2].b)
454      self.assertEqual(c_val, res[2].c)
455      self.assertTrue(isinstance(res[3], dict))
456      self.assertEqual(3, len(res[3]))
457      self.assertEqual(a_val, res[3]['a'])
458      self.assertEqual(b_val, res[3]['b'])
459      self.assertEqual(c_val, res[3]['c'])
460      # Tuple of lists, tuples, namedtuple, and dict
461      res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
462          'a': a,
463          'c': c,
464          'b': b
465      }))
466      self.assertTrue(isinstance(res, tuple))
467      self.assertEqual(4, len(res))
468      self.assertTrue(isinstance(res[0], list))
469      self.assertEqual(3, len(res[0]))
470      self.assertEqual(a_val, res[0][0])
471      self.assertEqual(b_val, res[0][1])
472      self.assertEqual(c_val, res[0][2])
473      self.assertTrue(isinstance(res[1], tuple))
474      self.assertEqual(3, len(res[1]))
475      self.assertEqual(a_val, res[1][0])
476      self.assertEqual(b_val, res[1][1])
477      self.assertEqual(c_val, res[1][2])
478      self.assertTrue(isinstance(res[2], ABC))
479      self.assertEqual(a_val, res[2].a)
480      self.assertEqual(b_val, res[2].b)
481      self.assertEqual(c_val, res[2].c)
482      self.assertTrue(isinstance(res[3], dict))
483      self.assertEqual(3, len(res[3]))
484      self.assertEqual(a_val, res[3]['a'])
485      self.assertEqual(b_val, res[3]['b'])
486      self.assertEqual(c_val, res[3]['c'])
487      # Namedtuple of lists, tuples, namedtuples, and dict
488      res = sess.run(
489          DEFG(
490              d=[a, b, c],
491              e=(a, b, c),
492              f=ABC(a=a.name, b=b, c=c),
493              g={
494                  'a': a,
495                  'c': c,
496                  'b': b
497              }))
498      self.assertTrue(isinstance(res, DEFG))
499      self.assertTrue(isinstance(res.d, list))
500      self.assertEqual(3, len(res.d))
501      self.assertEqual(a_val, res.d[0])
502      self.assertEqual(b_val, res.d[1])
503      self.assertEqual(c_val, res.d[2])
504      self.assertTrue(isinstance(res.e, tuple))
505      self.assertEqual(3, len(res.e))
506      self.assertEqual(a_val, res.e[0])
507      self.assertEqual(b_val, res.e[1])
508      self.assertEqual(c_val, res.e[2])
509      self.assertTrue(isinstance(res.f, ABC))
510      self.assertEqual(a_val, res.f.a)
511      self.assertEqual(b_val, res.f.b)
512      self.assertEqual(c_val, res.f.c)
513      self.assertTrue(isinstance(res.g, dict))
514      self.assertEqual(3, len(res.g))
515      self.assertEqual(a_val, res.g['a'])
516      self.assertEqual(b_val, res.g['b'])
517      self.assertEqual(c_val, res.g['c'])
518      # Dict of lists, tuples, namedtuples, and dict
519      res = sess.run({
520          'd': [a, b, c],
521          'e': (a, b, c),
522          'f': ABC(a=a, b=b, c=c),
523          'g': {
524              'a': a.name,
525              'c': c,
526              'b': b
527          }
528      })
529      self.assertTrue(isinstance(res, dict))
530      self.assertEqual(4, len(res))
531      self.assertTrue(isinstance(res['d'], list))
532      self.assertEqual(3, len(res['d']))
533      self.assertEqual(a_val, res['d'][0])
534      self.assertEqual(b_val, res['d'][1])
535      self.assertEqual(c_val, res['d'][2])
536      self.assertTrue(isinstance(res['e'], tuple))
537      self.assertEqual(3, len(res['e']))
538      self.assertEqual(a_val, res['e'][0])
539      self.assertEqual(b_val, res['e'][1])
540      self.assertEqual(c_val, res['e'][2])
541      self.assertTrue(isinstance(res['f'], ABC))
542      self.assertEqual(a_val, res['f'].a)
543      self.assertEqual(b_val, res['f'].b)
544      self.assertEqual(c_val, res['f'].c)
545      self.assertTrue(isinstance(res['g'], dict))
546      self.assertEqual(3, len(res['g']))
547      self.assertEqual(a_val, res['g']['a'])
548      self.assertEqual(b_val, res['g']['b'])
549      self.assertEqual(c_val, res['g']['c'])
550
551  def testFetchTensorObject(self):
552    with session.Session() as s:
553      a = constant_op.constant(1.0, shape=[1, 2])
554      b = constant_op.constant(2.0, shape=[2, 3])
555      c = math_ops.matmul(a, b)
556      results_with_list = s.run([c])
557      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
558      results_with_single = s.run(c)
559      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
560      results_with_get = c.eval()
561      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
562      a_val, b_val = s.run([a, b])  # Test multiple fetches.
563      self.assertAllEqual([[1.0, 1.0]], a_val)
564      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
565      results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]})
566      self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0])
567      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
568                          results_with_dict['b'])
569      self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0])
570      self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1])
571
572      # Test nested structures
573      results_with_nested_list = s.run([[[a, b], b], a, [a, b]])
574      self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0])
575      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
576                          results_with_nested_list[0][0][1])
577      self.assertAllEqual(results_with_nested_list[0][0][0],
578                          results_with_nested_list[1])
579      self.assertAllEqual(results_with_nested_list[1],
580                          results_with_nested_list[2][0])
581      self.assertAllEqual(results_with_nested_list[0][0][1],
582                          results_with_nested_list[0][1])
583      self.assertAllEqual(results_with_nested_list[0][1],
584                          results_with_nested_list[2][1])
585
586  def testFetchScalar(self):
587    with session.Session() as s:
588      for scalar in np.int32, np.int64, np.float16, np.float32, np.float64:
589        x = scalar(7)
590        y = scalar(8)
591        tf_x = constant_op.constant(x, shape=[])
592        tf_y = constant_op.constant(y)
593        tf_xy = math_ops.add(tf_x, tf_y)
594        # Single fetch
595        xy = s.run(tf_xy)
596        self.assertEqual(scalar, type(xy))
597        self.assertEqual(x + y, xy)
598        # List fetch
599        xy, = s.run([tf_xy])
600        self.assertEqual(scalar, type(xy))
601        self.assertEqual(x + y, xy)
602        # Dict fetch
603        xy = s.run({'xy': tf_xy})['xy']
604        self.assertEqual(scalar, type(xy))
605        self.assertEqual(x + y, xy)
606        # Nested list fetch
607        xy = s.run([[[tf_xy]], tf_xy, [tf_xy]])
608        self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]])
609        self.assertEqual(scalar, type(xy[0][0][0]))
610        self.assertEqual(scalar, type(xy[1]))
611        self.assertEqual(scalar, type(xy[2][0]))
612
613  def testFetchOperationObject(self):
614    with session.Session() as s:
615      a = constant_op.constant(1.0, shape=[1, 2])
616      v = variables.Variable(a, name='testFetchOperationObject_v')
617      s.run(v.initializer)
618      v_val = s.run(v)
619      self.assertAllEqual([[1.0, 1.0]], v_val)
620
621  def testFetchSparseTensor(self):
622    with session.Session() as s:
623      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
624      values = np.array([1.0, 2.0]).astype(np.float32)
625      shape = np.array([7, 9, 2]).astype(np.int64)
626      sp = sparse_tensor.SparseTensor(
627          constant_op.constant(indices), constant_op.constant(values),
628          constant_op.constant(shape))
629      # Single fetch, use as tuple
630      sp_out = s.run(sp)
631      indices_out, values_out, shape_out = sp_out
632      self.assertAllEqual(indices_out, indices)
633      self.assertAllEqual(values_out, values)
634      self.assertAllEqual(shape_out, shape)
635      # Single fetch, use as SparseTensorValue
636      sp_out = s.run(sp)
637      self.assertAllEqual(sp_out.indices, indices)
638      self.assertAllEqual(sp_out.values, values)
639      self.assertAllEqual(sp_out.dense_shape, shape)
640      # Tuple fetch, use as tuple
641      indices_out, values_out, shape_out = s.run(sp)
642      self.assertAllEqual(indices_out, indices)
643      self.assertAllEqual(values_out, values)
644      self.assertAllEqual(shape_out, shape)
645      # List fetch, use as tuple
646      (indices_out, values_out, shape_out), = s.run([sp])
647      self.assertAllEqual(indices_out, indices)
648      self.assertAllEqual(values_out, values)
649      self.assertAllEqual(shape_out, shape)
650      # List fetch, use as SparseTensorValue
651      sp_out, = s.run([sp])
652      self.assertAllEqual(sp_out.indices, indices)
653      self.assertAllEqual(sp_out.values, values)
654      self.assertAllEqual(sp_out.dense_shape, shape)
655      # Dict fetch (single value), use as tuple
656      indices_out, values_out, shape_out = s.run({'sp': sp})['sp']
657      self.assertAllEqual(indices_out, indices)
658      self.assertAllEqual(values_out, values)
659      self.assertAllEqual(shape_out, shape)
660      # Dict fetch (list value), use as tuple
661      (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp']
662      self.assertAllEqual(indices_out, indices)
663      self.assertAllEqual(values_out, values)
664      self.assertAllEqual(shape_out, shape)
665      # Dict fetch, use as SparseTensorValue
666      sp_out = s.run({'sp': sp})['sp']
667      self.assertAllEqual(sp_out.indices, indices)
668      self.assertAllEqual(sp_out.values, values)
669      self.assertAllEqual(sp_out.dense_shape, shape)
670      # Nested list fetch use as tuple
671      sp_out = s.run([[[sp]], sp])
672      indices_out, values_out, shape_out = sp_out[0][0][0]
673      self.assertAllEqual(indices_out, indices)
674      self.assertAllEqual(values_out, values)
675      self.assertAllEqual(shape_out, shape)
676      indices_out, values_out, shape_out = sp_out[1]
677      self.assertAllEqual(indices_out, indices)
678      self.assertAllEqual(values_out, values)
679      self.assertAllEqual(shape_out, shape)
680      # Nested list fetch, use as SparseTensorValue
681      sp_out = s.run([[[sp]], sp])
682      self.assertAllEqual(sp_out[0][0][0].indices, indices)
683      self.assertAllEqual(sp_out[0][0][0].values, values)
684      self.assertAllEqual(sp_out[0][0][0].dense_shape, shape)
685      self.assertAllEqual(sp_out[1].indices, indices)
686      self.assertAllEqual(sp_out[1].values, values)
687      self.assertAllEqual(sp_out[1].dense_shape, shape)
688
689  def testFeedSparseTensor(self):
690    with session.Session() as s:
691      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
692      values = np.array([1.0, 2.0]).astype(np.float32)
693      shape = np.array([7, 9, 2]).astype(np.int64)
694      sp = sparse_tensor.SparseTensor(
695          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
696          array_ops.placeholder(dtype=np.float32, shape=(2,)),
697          array_ops.placeholder(dtype=np.int64, shape=(3,)),
698      )
699      sp_indices = array_ops.identity(sp.indices)
700      sp_values = array_ops.identity(sp.values)
701      sp_shape = array_ops.identity(sp.dense_shape)
702      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
703      # Feed with tuple
704      indices_out, values_out, shape_out = s.run(
705          [sp_indices, sp_values, sp_shape], {
706              sp: (indices, values, shape)
707          })
708      self.assertAllEqual(indices_out, indices)
709      self.assertAllEqual(values_out, values)
710      self.assertAllEqual(shape_out, shape)
711      # Feed with tuple, fetch sp directly
712      sp_out = s.run(sp, {sp: (indices, values, shape)})
713      self.assertAllEqual(sp_out.indices, indices)
714      self.assertAllEqual(sp_out.values, values)
715      self.assertAllEqual(sp_out.dense_shape, shape)
716      # Feed with SparseTensorValue
717      indices_out, values_out, shape_out = s.run(
718          [sp_indices, sp_values, sp_shape], {
719              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
720          })
721      self.assertAllEqual(indices_out, indices)
722      self.assertAllEqual(values_out, values)
723      self.assertAllEqual(shape_out, shape)
724      # Feed with SparseTensorValue, fetch SparseTensorValue
725      sp2_out = s.run(sp2, {
726          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
727      })
728      self.assertAllEqual(sp2_out.indices, indices)
729      self.assertAllEqual(sp2_out.values, values)
730      self.assertAllEqual(sp2_out.dense_shape, shape)
731      # Feed SparseTensorValue and fetch sp directly.
732      sp_out = s.run(sp, {
733          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
734      })
735      self.assertAllEqual(sp_out.indices, indices)
736      self.assertAllEqual(sp_out.values, values)
737      self.assertAllEqual(sp_out.dense_shape, shape)
738
739  def testFeedSparsePlaceholder(self):
740    with session.Session() as s:
741      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
742      values = np.array([1.0, 2.0]).astype(np.float32)
743      shape = np.array([7, 9, 2]).astype(np.int64)
744      sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
745      sp_indices = array_ops.identity(sp.indices)
746      sp_values = array_ops.identity(sp.values)
747      sp_shape = array_ops.identity(sp.dense_shape)
748      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
749      # Feed with tuple
750      indices_out, values_out, shape_out = s.run(
751          [sp_indices, sp_values, sp_shape], {
752              sp: (indices, values, shape)
753          })
754      self.assertAllEqual(indices_out, indices)
755      self.assertAllEqual(values_out, values)
756      self.assertAllEqual(shape_out, shape)
757      # Feed with SparseTensorValue
758      indices_out, values_out, shape_out = s.run(
759          [sp_indices, sp_values, sp_shape], {
760              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
761          })
762      self.assertAllEqual(indices_out, indices)
763      self.assertAllEqual(values_out, values)
764      self.assertAllEqual(shape_out, shape)
765      # Feed with SparseTensorValue, fetch SparseTensorValue
766      sp2_out = s.run(sp2, {
767          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
768      })
769      self.assertAllEqual(sp2_out.indices, indices)
770      self.assertAllEqual(sp2_out.values, values)
771      self.assertAllEqual(sp2_out.dense_shape, shape)
772
773  def testFeedSparsePlaceholderPartialShape(self):
774    with session.Session() as s:
775      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
776      values = np.array([1.0, 2.0]).astype(np.float32)
777      shape = np.array([7, 9, 2]).astype(np.int64)
778      sp = array_ops.sparse_placeholder(
779          shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
780      sp_indices = array_ops.identity(sp.indices)
781      sp_values = array_ops.identity(sp.values)
782      sp_shape = array_ops.identity(sp.dense_shape)
783      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
784      # Feed with tuple
785      indices_out, values_out, shape_out = s.run(
786          [sp_indices, sp_values, sp_shape], {
787              sp: (indices, values, shape)
788          })
789      self.assertAllEqual(indices_out, indices)
790      self.assertAllEqual(values_out, values)
791      self.assertAllEqual(shape_out, shape)
792      # Feed with SparseTensorValue
793      indices_out, values_out, shape_out = s.run(
794          [sp_indices, sp_values, sp_shape], {
795              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
796          })
797      self.assertAllEqual(indices_out, indices)
798      self.assertAllEqual(values_out, values)
799      self.assertAllEqual(shape_out, shape)
800      # Feed with SparseTensorValue, fetch SparseTensorValue
801      sp2_out = s.run(sp2, {
802          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
803      })
804      self.assertAllEqual(sp2_out.indices, indices)
805      self.assertAllEqual(sp2_out.values, values)
806      self.assertAllEqual(sp2_out.dense_shape, shape)
807
808  def testFeedSparsePlaceholderConstantShape(self):
809    with session.Session() as s:
810      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
811      values = np.array([1.0, 2.0]).astype(np.float32)
812      shape = np.array([7, 9, 2]).astype(np.int64)
813      sp = array_ops.sparse_placeholder(
814          dtype=np.float32, shape=shape, name='placeholder1')
815      self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
816      self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape)
817      sp_indices = array_ops.identity(sp.indices)
818      sp_values = array_ops.identity(sp.values)
819      sp_shape = array_ops.identity(sp.dense_shape)
820      # Feed with tuple
821      indices_out, values_out, shape_out = s.run(
822          [sp_indices, sp_values, sp_shape], {
823              sp: (indices, values)
824          })
825      self.assertAllEqual(indices_out, indices)
826      self.assertAllEqual(values_out, values)
827      self.assertAllEqual(shape_out, shape)
828
829  def testFetchIndexedSlices(self):
830    with session.Session() as s:
831      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
832      values = np.array([1.0, 2.0]).astype(np.float32)
833      dense_shape = np.array([7, 9, 2]).astype(np.int64)
834      ind = ops.IndexedSlices(
835          constant_op.constant(values), constant_op.constant(indices),
836          constant_op.constant(dense_shape))
837      # Single fetch, use as tuple
838      ind_out = s.run(ind)
839      values_out, indices_out, dense_shape_out = ind_out
840      self.assertAllEqual(values_out, values)
841      self.assertAllEqual(indices_out, indices)
842      self.assertAllEqual(dense_shape_out, dense_shape)
843      # Single fetch, use as IndexedSlicesValue
844      ind_out = s.run(ind)
845      self.assertAllEqual(ind_out.values, values)
846      self.assertAllEqual(ind_out.indices, indices)
847      self.assertAllEqual(ind_out.dense_shape, dense_shape)
848      # Tuple fetch, use as tuple
849      values_out, indices_out, dense_shape_out = s.run(ind)
850      self.assertAllEqual(values_out, values)
851      self.assertAllEqual(indices_out, indices)
852      self.assertAllEqual(dense_shape_out, dense_shape)
853      # List fetch, use as tuple
854      (values_out, indices_out, dense_shape_out), = s.run([ind])
855      self.assertAllEqual(values_out, values)
856      self.assertAllEqual(indices_out, indices)
857      self.assertAllEqual(dense_shape_out, dense_shape)
858      # List fetch, use as IndexedSlicesValue
859      ind_out, = s.run([ind])
860      self.assertAllEqual(ind_out.values, values)
861      self.assertAllEqual(ind_out.indices, indices)
862      self.assertAllEqual(ind_out.dense_shape, dense_shape)
863
864  def testFeedIndexedSlices(self):
865    with session.Session() as s:
866      values = np.array([1.0, 2.0]).astype(np.float32)
867      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
868      dense_shape = np.array([7, 9, 2]).astype(np.int64)
869      ind = ops.IndexedSlices(
870          array_ops.placeholder(dtype=np.float32, shape=(2,)),
871          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
872          array_ops.placeholder(dtype=np.int64, shape=(3,)),
873      )
874      ind_values = array_ops.identity(ind.values)
875      ind_indices = array_ops.identity(ind.indices)
876      ind_dense_shape = array_ops.identity(ind.dense_shape)
877      ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
878      # Feed with tuple
879      values_out, indices_out, dense_shape_out = s.run(
880          [ind_values, ind_indices, ind_dense_shape], {
881              ind: (values, indices, dense_shape)
882          })
883      self.assertAllEqual(values_out, values)
884      self.assertAllEqual(indices_out, indices)
885      self.assertAllEqual(dense_shape_out, dense_shape)
886      # Feed with IndexedSlicesValue
887      values_out, indices_out, dense_shape_out = s.run(
888          [ind_values, ind_indices, ind_dense_shape], {
889              ind: ops.IndexedSlicesValue(values, indices, dense_shape)
890          })
891      self.assertAllEqual(values_out, values)
892      self.assertAllEqual(indices_out, indices)
893      self.assertAllEqual(dense_shape_out, dense_shape)
894      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
895      ind2_out = s.run(ind2, {
896          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
897      })
898      self.assertAllEqual(ind2_out.values, values)
899      self.assertAllEqual(ind2_out.indices, indices)
900      self.assertAllEqual(ind2_out.dense_shape, dense_shape)
901
902  def testFetchIndexedSlicesWithoutDenseShape(self):
903    with session.Session() as s:
904      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
905      values = np.array([1.0, 2.0]).astype(np.float32)
906      dense_shape = None
907      ind = ops.IndexedSlices(
908          constant_op.constant(values), constant_op.constant(indices), None)
909      # Single fetch, use as tuple
910      ind_out = s.run(ind)
911      values_out, indices_out, dense_shape_out = ind_out
912      self.assertAllEqual(values_out, values)
913      self.assertAllEqual(indices_out, indices)
914      self.assertAllEqual(dense_shape_out, dense_shape)
915      # Single fetch, use as IndexedSlicesValue
916      ind_out = s.run(ind)
917      self.assertAllEqual(ind_out.values, values)
918      self.assertAllEqual(ind_out.indices, indices)
919      self.assertAllEqual(ind_out.dense_shape, dense_shape)
920      # Tuple fetch, use as tuple
921      values_out, indices_out, dense_shape_out = s.run(ind)
922      self.assertAllEqual(values_out, values)
923      self.assertAllEqual(indices_out, indices)
924      self.assertAllEqual(dense_shape_out, dense_shape)
925      # List fetch, use as tuple
926      (values_out, indices_out, dense_shape_out), = s.run([ind])
927      self.assertAllEqual(values_out, values)
928      self.assertAllEqual(indices_out, indices)
929      self.assertAllEqual(dense_shape_out, dense_shape)
930      # List fetch, use as IndexedSlicesValue
931      ind_out, = s.run([ind])
932      self.assertAllEqual(ind_out.values, values)
933      self.assertAllEqual(ind_out.indices, indices)
934      self.assertAllEqual(ind_out.dense_shape, dense_shape)
935
936  def testFeedIndexedSlicesWithoutDenseShape(self):
937    with session.Session() as s:
938      values = np.array([1.0, 2.0]).astype(np.float32)
939      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
940      dense_shape = None
941      ind = ops.IndexedSlices(
942          array_ops.placeholder(dtype=np.float32, shape=(2,)),
943          array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
944      ind_values = array_ops.identity(ind.values)
945      ind_indices = array_ops.identity(ind.indices)
946      ind2 = ops.IndexedSlices(ind_values, ind_indices)
947      # Feed with tuple
948      values_out, indices_out = s.run([ind_values, ind_indices], {
949          ind: (values, indices)
950      })
951      self.assertAllEqual(values_out, values)
952      self.assertAllEqual(indices_out, indices)
953      # Feed with IndexedSlicesValue
954      values_out, indices_out = s.run([ind_values, ind_indices], {
955          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
956      })
957      self.assertAllEqual(values_out, values)
958      self.assertAllEqual(indices_out, indices)
959      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
960      ind2_out = s.run(ind2, {
961          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
962      })
963      self.assertAllEqual(ind2_out.values, values)
964      self.assertAllEqual(ind2_out.indices, indices)
965      self.assertAllEqual(ind2_out.dense_shape, dense_shape)
966
  def testExtendWithStatelessOperations(self):
    """Stateless ops added after a run() are usable in a later run()."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      c_val = s.run(c)
      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
      # d and e are created only after the first run(), so the session has to
      # pick up the newly added nodes before it can run e.
      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
      e = math_ops.matmul(c, d)
      # Extend will happen here.
      e_val = s.run(e)
      self.assertAllEqual([[24.0]], e_val)
979
  def testExtendWithStatefulOperations(self):
    """Extending the graph after a run keeps existing variable values intact."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      # The nodes below are added after v has been initialized and read.
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      # Extend will happen here.
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # Building (but not running) the assign op must not change v.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
1000
  def testExtendWithGroupBy(self):
    """group() can combine initializers created before and after an Extend."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      p = variables.Variable(a, name='testExtendWithGroupBy_p')
      a_val = a.eval()  # Force an Extend after this op.
      self.assertAllEqual([[1.0, 1.0]], a_val)

      # q is created after the eval() above, so its initializer is a new node.
      b = constant_op.constant(2.0, shape=[1, 2])
      q = variables.Variable(b, name='testExtendWithGroupBy_q')
      # Extend will happen here.
      init = control_flow_ops.group(p.initializer, q.initializer)
      s.run(init)
      p_val, q_val = s.run([p, q])

      self.assertAllEqual([[1.0, 1.0]], p_val)
      self.assertAllEqual([[2.0, 2.0]], q_val)
1017
1018  def testTensorGetMethod(self):
1019    with session.Session():
1020      a = constant_op.constant(1.0, shape=[1, 2])
1021      b = constant_op.constant(2.0, shape=[2, 3])
1022      c = math_ops.matmul(a, b)
1023
1024      c_val = c.eval()
1025      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
1026
1027      fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
1028      self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
1029
1030  @test_util.run_v1_only('b/120545219')
1031  def testOperationRunMethod(self):
1032    with session.Session():
1033      a = constant_op.constant(1.0, shape=[1, 2])
1034      b = constant_op.constant(2.0, shape=[1, 2], name='b')
1035      v = variables.VariableV1(a, a.dtype)
1036      assign_a_to_v = state_ops.assign(v, a)
1037
1038      assign_a_to_v.eval()
1039
1040      v_val = v.eval()
1041      self.assertAllEqual([[1.0, 1.0]], v_val)
1042
1043      assign_b_to_v = state_ops.assign(v, b)
1044
1045      assign_b_to_v.eval()
1046      v_val = v.eval()
1047      self.assertAllEqual([[2.0, 2.0]], v_val)
1048
1049      assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
1050      v_val = v.eval()
1051      self.assertAllEqual([[3.0, 3.0]], v_val)
1052
1053  def testDefaultGraph(self):
1054    with session.Session() as s:
1055      self.assertEqual(ops.get_default_graph(), s.graph)
1056      a = constant_op.constant(1.0, shape=[1, 2])
1057      b = constant_op.constant(2.0, shape=[2, 3])
1058      self.assertEqual(ops.get_default_graph(), a.graph)
1059      self.assertEqual(ops.get_default_graph(), b.graph)
1060      c = math_ops.matmul(a, b)
1061      v = variables.Variable(c, name='testDefaultGraph_v')
1062      v.initializer.run()
1063      v_val = v.eval()
1064      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1065      d = constant_op.constant(3.0, shape=[2, 3])
1066      e = math_ops.matmul(a, d)
1067      assign_e_to_v = state_ops.assign(v, e)
1068      e_val = e.eval()
1069      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
1070      v_val = v.eval()
1071      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1072      s.run(assign_e_to_v)
1073      v_val = v.eval()
1074      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
1075      self.assertEqual(ops.get_default_graph(), s.graph)
1076
  def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
    """Worker body used by testDefaultGraphWithThreads.

    Opens a Session on this thread's default graph and builds a variable.
    It then waits until the controlling thread releases all workers before
    running any ops.

    Args:
      constructed_event: threading.Event this worker sets once its variable
        has been created.
      continue_event: threading.Event this worker waits on before running ops.
      i: Worker index, used to give the variable a distinct name.
    """
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='var_%d' % i)

      # Block here until all threads have constructed their graph.
      constructed_event.set()
      continue_event.wait()

      assign_c_to_v = state_ops.assign(v, c)
      v.initializer.run()
      assign_c_to_v.eval()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # assign_e_to_v has not run yet, so v still holds c's value.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)
1105
1106  def testDefaultGraphWithThreads(self):
1107    # Fork ten threads that use their thread-local default graph.
1108    threads = []
1109    constructed_events = [threading.Event() for _ in range(10)]
1110    continue_event = threading.Event()
1111    for i, constructed_event in enumerate(constructed_events):
1112      t = self.checkedThread(
1113          target=self._testDefaultGraphInThread,
1114          args=(constructed_event, continue_event, i))
1115      threads.append(t)
1116    for t in threads:
1117      t.start()
1118    for constructed_event in constructed_events:
1119      constructed_event.wait()
1120    continue_event.set()
1121    for t in threads:
1122      t.join()
1123
1124  def testParallelRun(self):
1125    with session.Session() as sess:
1126      c = constant_op.constant(5.0)
1127      ev = threading.Event()
1128
1129      def run_step():
1130        ev.wait()
1131        val = c.eval(session=sess)
1132        self.assertEqual(val, 5.0)
1133
1134      threads = [self.checkedThread(target=run_step) for _ in range(100)]
1135      for t in threads:
1136        t.start()
1137      ev.set()
1138      for t in threads:
1139        t.join()
1140
  @staticmethod
  def _build_graph():
    """Adds a batch of ops to the current default graph.

    Called by the parallel run/build tests to mutate a graph while other
    threads are running a session on it.
    """
    # Random delay of up to 0.1s; presumably meant to stagger concurrent
    # callers -- confirm.
    time.sleep(random.random() * 0.1)
    # Do some graph construction. Try to exercise non-trivial paths.
    graph = ops.get_default_graph()
    # GraphDef snapshot taken on the first iteration and re-imported on every
    # later iteration.
    gdef = None
    for _ in range(10):
      x = array_ops.placeholder(dtype=dtypes.float32)
      # y is built under a colocation constraint with x.
      with ops.colocate_with(x):
        y = array_ops.placeholder(dtype=dtypes.float32)
      with ops.device('/cpu:0'):
        z = control_flow_ops.while_loop(
            lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
      # Private Graph API; presumably attaches attr `_a` to ops created in
      # this scope -- confirm against ops.Graph._attr_scope.
      with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
        gradients_impl.gradients(z, [x, y])
        if gdef is None:
          gdef = graph.as_graph_def()
        else:
          importer.import_graph_def(gdef, name='import')
1160
1161  @test_util.run_v1_only('b/120545219')
1162  def testParallelRunAndSingleBuild(self):
1163    with session.Session() as sess:
1164      c = constant_op.constant(5.0)
1165      stop = threading.Event()
1166
1167      def run_loop():
1168        while not stop.is_set():
1169          time.sleep(random.random() * 0.1)
1170          self.assertEqual(sess.run(c), 5.0)
1171
1172      threads = [self.checkedThread(target=run_loop) for _ in range(10)]
1173      for t in threads:
1174        t.start()
1175
1176      SessionTest._build_graph()
1177
1178      stop.set()
1179      for t in threads:
1180        t.join()
1181
1182  @test_util.run_v1_only('b/120545219')
1183  def testParallelRunAndParallelBuild(self):
1184    with session.Session() as sess:
1185      c = constant_op.constant(5.0)
1186      stop = threading.Event()
1187
1188      def run_loop():
1189        while not stop.is_set():
1190          time.sleep(random.random() * 0.1)
1191          self.assertEqual(sess.run(c), 5.0)
1192
1193      run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
1194      for t in run_threads:
1195        t.start()
1196
1197      build_threads = [self.checkedThread(target=SessionTest._build_graph)
1198                       for _ in range(10)]
1199      for t in build_threads:
1200        t.start()
1201      for t in build_threads:
1202        t.join()
1203
1204      # Let the run_threads run until the build threads are finished.
1205      stop.set()
1206      for t in run_threads:
1207        t.join()
1208
1209  def testRunFeedDict(self):
1210    with session.Session() as s:
1211      x = array_ops.zeros([2])
1212
1213      y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
1214      self.assertAllEqual(y, 2 * np.ones(2))
1215
1216      y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
1217      self.assertAllEqual(y, 2 * np.ones(2))
1218
1219      y = s.run(2 * x, feed_dict={x: [1, 1]})
1220      assert (y == 2 * np.ones(2)).all()
1221
1222      # Test nested tuple keys
1223      z = (((array_ops.zeros([2]),),), array_ops.zeros([2]),
1224           (array_ops.zeros([2]),))
1225      result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2]
1226      values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),))
1227      result_value = s.run(result, feed_dict={z: values})
1228      self.assertAllEqual(result_value[0], 2 * np.ones(2))
1229      self.assertAllEqual(result_value[1], 2 * np.array([2, 2]))
1230      self.assertAllEqual(result_value[2], 2 * np.array([3, 3]))
1231
1232  def testGraphDef(self):
1233    with session.Session() as sess:
1234      self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
1235                             (versions.GRAPH_DEF_VERSION,
1236                              versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
1237                             sess.graph_def)
1238      c = constant_op.constant(5.0, name='c')
1239      self.assertEquals(len(sess.graph_def.node), 1)
1240      d = constant_op.constant(6.0, name='d')
1241      self.assertEquals(len(sess.graph_def.node), 2)
1242      self.assertAllEqual(c.eval(), 5.0)
1243      self.assertAllEqual(d.eval(), 6.0)
1244      e = constant_op.constant(7.0, name='e')
1245      self.assertEquals(len(sess.graph_def.node), 3)
1246      self.assertAllEqual(e.eval(), 7.0)
1247
1248  def testUseAfterClose(self):
1249    with session.Session() as sess:
1250      c = constant_op.constant(5.0)
1251      self.assertAllEqual(sess.run(c), 5.0)
1252    with self.assertRaisesWithPredicateMatch(
1253        RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
1254      sess.run(c)
1255
1256  def testUseAfterCloseConcurrent(self):
1257    with session.Session() as sess:
1258      c = constant_op.constant(5.0)
1259      self.assertAllEqual(sess.run(c), 5.0)
1260
1261      def update_thread():
1262        with self.assertRaisesWithPredicateMatch(
1263            RuntimeError,
1264            lambda e: 'Attempted to use a closed Session.' in str(e)):
1265          while True:
1266            sess.run(c)
1267
1268      t = threading.Thread(target=update_thread)
1269      t.start()
1270      time.sleep(0.1)
1271      sess.close()
1272      t.join()
1273
1274  def testUseEmptyGraph(self):
1275    with session.Session() as sess:
1276      with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
1277        sess.run([])
1278      with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
1279        sess.run(())
1280      with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
1281        sess.run({})
1282
1283  @test_util.run_v1_only('b/120545219')
1284  def testNotEntered(self):
1285    # pylint: disable=protected-access
1286    self.assertEqual(ops._default_session_stack.get_default(), None)
1287    # pylint: enable=protected-access
1288    with ops.device('/cpu:0'):
1289      sess = session.Session()
1290      c_1 = constant_op.constant(5.0)
1291      with sess.graph.as_default():
1292        c_2 = constant_op.constant(5.0)
1293      self.assertEqual(c_1.graph, c_2.graph)
1294      self.assertEqual(sess.run(c_2), 5.0)
1295      with self.assertRaisesWithPredicateMatch(
1296          ValueError, lambda e: 'No default session is registered.' in str(e)):
1297        c_2.eval()
1298
1299  @test_util.run_v1_only('b/120545219')
1300  def testInteractive(self):
1301    with ops.device('/cpu:0'):
1302      sess = session.InteractiveSession()
1303      a = constant_op.constant(1.0, shape=[1, 2])
1304      b = constant_op.constant(2.0, shape=[2, 3])
1305      c = math_ops.matmul(a, b)
1306      self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval())
1307      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
1308      e = math_ops.matmul(c, d)
1309      self.assertAllEqual([[24.0]], e.eval())
1310      sess.close()
1311
1312  @test_util.run_v1_only('b/120545219')
1313  def testMultipleInteractiveSessionsWarning(self):
1314    # Reinitialize the global state to ensure that the expected warnings will
1315    # be emitted.
1316    session.InteractiveSession._active_session_count = 0  # pylint: disable=protected-access
1317
1318    sess = session.InteractiveSession()
1319    sess.run(constant_op.constant(4.0))  # Run so that the session is "opened".
1320    sess.close()
1321    # Opening and closing interactive sessions serially should not warn.
1322    with warnings.catch_warnings(record=True) as w:
1323      sess = session.InteractiveSession()
1324      sess.close()
1325    self.assertEqual(0, len(w))
1326
1327    with warnings.catch_warnings(record=True) as w:
1328      sess = session.InteractiveSession()
1329    self.assertEqual(0, len(w))
1330    with warnings.catch_warnings(record=True) as w:
1331      sess2 = session.InteractiveSession()
1332    self.assertEqual(1, len(w))
1333    self.assertTrue('An interactive session is already active. This can cause '
1334                    'out-of-memory errors in some cases. You must explicitly '
1335                    'call `InteractiveSession.close()` to release resources '
1336                    'held by the other session(s).' in str(w[0].message))
1337    sess2.close()
1338    sess.close()
1339
1340  @test_util.run_v1_only('b/120545219')
1341  def testInteractivePlacePrunedGraph(self):
1342    sess = session.InteractiveSession()
1343
1344    # Build a graph that has a bad op in it (no kernel).
1345    #
1346    # This test currently does not link in any GPU kernels,
1347    # which is why placing this is invalid.  If at some point
1348    # GPU kernels are added to this test, some other different
1349    # op / device combo should be chosen.
1350    with ops.device('/device:GPU:0'):
1351      a = constant_op.constant(1.0, shape=[1, 2])
1352
1353    b = constant_op.constant(1.0, shape=[1, 2])
1354
1355    # Only run the valid op, this should work.
1356    b.eval()
1357
1358    with self.assertRaises(errors.InvalidArgumentError):
1359      a.eval()
1360    sess.close()
1361
1362  @test_util.run_v1_only('b/120545219')
1363  def testDefaultSessionPlacePrunedGraph(self):
1364    sess = session.Session()
1365
1366    # Build a graph that has a bad op in it (no kernel).
1367    #
1368    # This test currently does not link in any GPU kernels,
1369    # which is why placing this is invalid.  If at some point
1370    # GPU kernels are added to this test, some other different
1371    # op / device combo should be chosen.
1372    with ops.device('/device:GPU:0'):
1373      _ = constant_op.constant(1.0, shape=[1, 2])
1374
1375    b = constant_op.constant(1.0, shape=[1, 2])
1376
1377    with self.assertRaises(errors.InvalidArgumentError):
1378      # Even though we don't run the bad op, we place the entire
1379      # graph, which should fail with a non-interactive session.
1380      sess.run(b)
1381
1382    sess.close()
1383
1384  def testSharedGraph(self):
1385    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
1386      a = constant_op.constant(1.0, shape=[1, 2])
1387      b = constant_op.constant(2.0, shape=[2, 3])
1388      c = math_ops.matmul(a, b)
1389
1390    with session.Session(graph=g) as sess1:
1391      with session.Session(graph=g) as sess2:
1392        self.assertAllEqual(sess1.run(c), sess2.run(c))
1393
1394  def testDuplicatedInputs(self):
1395    with session.Session() as sess:
1396      a = constant_op.constant(1.0, shape=[1, 2])
1397      b = constant_op.constant(2.0, shape=[1, 3])
1398      a_val, b_val, a2_val = sess.run([a, b, a])
1399      self.assertAllEqual(a_val, [[1.0, 1.0]])
1400      self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
1401      self.assertAllEqual(a2_val, [[1.0, 1.0]])
1402
1403  def testFeedAndFetch(self):
1404    with session.Session() as sess:
1405      for dtype in [
1406          dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
1407          dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
1408          dtypes.complex64, dtypes.complex128
1409      ]:
1410        for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1411          np_dtype = dtype.as_numpy_dtype
1412
1413          feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
1414          out_t = array_ops.identity(feed_t)
1415
1416          np_array = np.random.randint(-10, 10, shape)
1417
1418          if dtype == dtypes.bool:
1419            np_array = np_array > 0
1420          elif dtype == dtypes.complex64:
1421            np_array = np.sqrt(np_array.astype(np_dtype))
1422          elif dtype == dtypes.complex64:
1423            np_array = np.sqrt(np_array.astype(np_dtype))
1424          else:
1425            np_array = np_array.astype(np_dtype)
1426
1427          self.assertAllEqual(np_array,
1428                              sess.run(out_t, feed_dict={
1429                                  feed_t: np_array
1430                              }))
1431          # Check that we can also get the feed back.
1432          self.assertAllEqual(np_array,
1433                              sess.run(feed_t, feed_dict={
1434                                  feed_t: np_array
1435                              }))
1436          # Also check that we can get both back.
1437          out_v, feed_v = sess.run(
1438              [out_t, feed_t], feed_dict={
1439                  feed_t: np_array
1440              })
1441          self.assertAllEqual(np_array, out_v)
1442          self.assertAllEqual(np_array, feed_v)
1443
1444          feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
1445          out_v, feed_v = feed_fetch_runner(np_array)
1446          self.assertAllEqual(np_array, out_v)
1447          self.assertAllEqual(np_array, feed_v)
1448
1449  def testMakeCallableOnTensorWithRunOptions(self):
1450    with session.Session() as sess:
1451      a = constant_op.constant(42.0)
1452      tensor_runner = sess.make_callable(a, accept_options=True)
1453      run_options = config_pb2.RunOptions(
1454          trace_level=config_pb2.RunOptions.FULL_TRACE)
1455      run_metadata = config_pb2.RunMetadata()
1456      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1457      res = tensor_runner(options=run_options, run_metadata=run_metadata)
1458      self.assertEqual(42.0, res)
1459      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1460
1461  def testMakeCallableOnOperationWithRunOptions(self):
1462    with session.Session() as sess:
1463      a = variables.Variable(42.0)
1464      b = state_ops.assign_add(a, 1.0)
1465      sess.run(a.initializer)
1466      tensor_runner = sess.make_callable(b.op, accept_options=True)
1467      run_options = config_pb2.RunOptions(
1468          trace_level=config_pb2.RunOptions.FULL_TRACE)
1469      run_metadata = config_pb2.RunMetadata()
1470      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1471      tensor_runner(options=run_options, run_metadata=run_metadata)
1472      self.assertEqual(43.0, sess.run(a))
1473      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1474
1475  def testMakeCallableWithFeedListAndRunOptions(self):
1476    with session.Session() as sess:
1477      ph = array_ops.placeholder(dtypes.float32)
1478      a = math_ops.add(ph, 1.0)
1479      tensor_runner = sess.make_callable(
1480          a, feed_list=[ph.name], accept_options=True)
1481      run_options = config_pb2.RunOptions(
1482          trace_level=config_pb2.RunOptions.FULL_TRACE)
1483      run_metadata = config_pb2.RunMetadata()
1484      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1485      self.assertAllClose(42.0,
1486                          tensor_runner(
1487                              41.0,
1488                              options=run_options,
1489                              run_metadata=run_metadata))
1490      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1491
1492  def testOptimizedMakeCallable(self):
1493    with session.Session() as sess:
1494      ph = array_ops.placeholder(dtypes.float32)
1495      a = math_ops.add(ph, 1.0)
1496      callable_opts = config_pb2.CallableOptions()
1497      callable_opts.feed.append(ph.name)
1498      callable_opts.fetch.append(a.name)
1499      for _ in range(3):
1500        callable_fn = sess._make_callable_from_options(callable_opts)
1501        for _ in range(5):
1502          self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))
1503
1504  def testOptimizedMakeCallableWithRunMetadata(self):
1505    with session.Session() as sess:
1506      ph = array_ops.placeholder(dtypes.float32)
1507      a = math_ops.add(ph, 1.0)
1508      callable_opts = config_pb2.CallableOptions()
1509      callable_opts.feed.append(ph.name)
1510      callable_opts.fetch.append(a.name)
1511      callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
1512      callable_fn = sess._make_callable_from_options(callable_opts)
1513      run_metadata = config_pb2.RunMetadata()
1514      self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32),
1515                                          run_metadata=run_metadata))
1516      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1517
1518  def testFeedError(self):
1519    with session.Session() as sess:
1520      feed_t = array_ops.placeholder(dtype=dtypes.float32)
1521      out_t = array_ops.identity(feed_t)
1522      feed_val = constant_op.constant(5.0)
1523      with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
1524        sess.run(out_t, feed_dict={feed_t: feed_val})
1525      with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
1526        out_t.eval(feed_dict={feed_t: feed_val})
1527      with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
1528        out_t.op.run(feed_dict={feed_t: feed_val})
1529
1530  def testFeedPrecisionLossError(self):
1531    with session.Session() as sess:
1532      largest_int64 = np.iinfo(np.int64).max
1533
1534      feed_int_implicit_int32 = constant_op.constant(1)
1535      feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32)
1536
1537      out_t = constant_op.constant(1.0)
1538
1539      with self.assertRaisesRegexp(TypeError,
1540                                   'is not compatible with Tensor type'):
1541        sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64})
1542      with self.assertRaisesRegexp(TypeError,
1543                                   'is not compatible with Tensor type'):
1544        sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64})
1545
1546  def testStringFetch(self):
1547    with session.Session():
1548      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1549        size = 1
1550        for s in shape:
1551          size *= s
1552        c_list = np.array(
1553            [compat.as_bytes(str(i)) for i in xrange(size)],
1554            dtype=np.object).reshape(shape) if size > 0 else []
1555        c = constant_op.constant(c_list)
1556        self.assertAllEqual(c.eval(), c_list)
1557
1558  def testStringFeed(self):
1559    with session.Session() as sess:
1560      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1561        size = 1
1562        for s in shape:
1563          size *= s
1564        c_list = np.array(
1565            [compat.as_bytes(str(i)) for i in xrange(size)],
1566            dtype=np.object).reshape(shape)
1567        feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
1568        c = array_ops.identity(feed_t)
1569        self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
1570        self.assertAllEqual(
1571            sess.run(feed_t, feed_dict={
1572                feed_t: c_list
1573            }), c_list)
1574        c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
1575        self.assertAllEqual(c_v, c_list)
1576        self.assertAllEqual(feed_v, c_list)
1577
1578  def testStringFeedWithNullCharacters(self):
1579    with session.Session():
1580      c_list = [b'\n\x01\x00', b'\n\x00\x01']
1581      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2])
1582      c = array_ops.identity(feed_t)
1583      out = c.eval(feed_dict={feed_t: c_list})
1584      self.assertEqual(c_list[0], out[0])
1585      self.assertEqual(c_list[1], out[1])
1586
1587  def testStringFeedWithUnicode(self):
1588    with session.Session():
1589      c_list = [
1590          u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
1591          u'\U0001f60e deal with it'
1592      ]
1593      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
1594      c = array_ops.identity(feed_t)
1595
1596      out = c.eval(feed_dict={feed_t: c_list})
1597      for i in range(len(c_list)):
1598        self.assertEqual(c_list[i], out[i].decode('utf-8'))
1599
1600      out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)})
1601      for i in range(len(c_list)):
1602        self.assertEqual(c_list[i], out[i].decode('utf-8'))
1603
1604  def testInvalidTargetFails(self):
1605    with self.assertRaisesRegexp(
1606        errors.NotFoundError,
1607        'No session factory registered for the given session options'):
1608      session.Session('INVALID_TARGET')
1609
1610  def testFetchByNameDifferentStringTypes(self):
1611    with session.Session() as sess:
1612      c = constant_op.constant(42.0, name='c')
1613      d = constant_op.constant(43.0, name=u'd')
1614      e = constant_op.constant(44.0, name=b'e')
1615      f = constant_op.constant(45.0, name=r'f')
1616
1617      self.assertTrue(isinstance(c.name, six.text_type))
1618      self.assertTrue(isinstance(d.name, six.text_type))
1619      self.assertTrue(isinstance(e.name, six.text_type))
1620      self.assertTrue(isinstance(f.name, six.text_type))
1621
1622      self.assertEqual(42.0, sess.run('c:0'))
1623      self.assertEqual(42.0, sess.run(u'c:0'))
1624      self.assertEqual(42.0, sess.run(b'c:0'))
1625      self.assertEqual(42.0, sess.run(r'c:0'))
1626
1627      self.assertEqual(43.0, sess.run('d:0'))
1628      self.assertEqual(43.0, sess.run(u'd:0'))
1629      self.assertEqual(43.0, sess.run(b'd:0'))
1630      self.assertEqual(43.0, sess.run(r'd:0'))
1631
1632      self.assertEqual(44.0, sess.run('e:0'))
1633      self.assertEqual(44.0, sess.run(u'e:0'))
1634      self.assertEqual(44.0, sess.run(b'e:0'))
1635      self.assertEqual(44.0, sess.run(r'e:0'))
1636
1637      self.assertEqual(45.0, sess.run('f:0'))
1638      self.assertEqual(45.0, sess.run(u'f:0'))
1639      self.assertEqual(45.0, sess.run(b'f:0'))
1640      self.assertEqual(45.0, sess.run(r'f:0'))
1641
1642  def testIncorrectGraph(self):
1643    with ops.Graph().as_default() as g_1:
1644      c_1 = constant_op.constant(1.0, name='c')
1645
1646    with ops.Graph().as_default() as g_2:
1647      c_2 = constant_op.constant(2.0, name='c')
1648
1649    self.assertEqual('c', c_1.op.name)
1650    self.assertEqual('c', c_2.op.name)
1651
1652    with session.Session(graph=g_1) as sess_1:
1653      self.assertEqual(1.0, sess_1.run(c_1))
1654      with self.assertRaises(ValueError):
1655        sess_1.run(c_2)
1656      with self.assertRaises(ValueError):
1657        sess_1.run(c_2.op)
1658
1659    with session.Session(graph=g_2) as sess_2:
1660      with self.assertRaises(ValueError):
1661        sess_2.run(c_1)
1662      with self.assertRaises(ValueError):
1663        sess_2.run(c_1.op)
1664      self.assertEqual(2.0, sess_2.run(c_2))
1665
1666  def testFeedDictKeyException(self):
1667    with session.Session() as sess:
1668      a = constant_op.constant(1.0, dtypes.float32, name='a')
1669      with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'):
1670        sess.run(a, feed_dict={'a': [2.0]})
1671
1672  def testPerStepTrace(self):
1673    run_options = config_pb2.RunOptions(
1674        trace_level=config_pb2.RunOptions.FULL_TRACE)
1675    run_metadata = config_pb2.RunMetadata()
1676
1677    with ops.device('/cpu:0'):
1678      with session.Session() as sess:
1679        sess.run(constant_op.constant(1.0))
1680        self.assertTrue(not run_metadata.HasField('step_stats'))
1681
1682        sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
1683        self.assertTrue(not run_metadata.HasField('step_stats'))
1684
1685        sess.run(
1686            constant_op.constant(1.0),
1687            options=run_options,
1688            run_metadata=run_metadata)
1689
1690        self.assertTrue(run_metadata.HasField('step_stats'))
1691        self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
1692
1693  def testRunOptionsRunMetadata(self):
1694    run_options = config_pb2.RunOptions(
1695        trace_level=config_pb2.RunOptions.FULL_TRACE)
1696    run_metadata = config_pb2.RunMetadata()
1697
1698    with ops.device('/cpu:0'):
1699      with session.Session() as sess:
1700        # all combinations are valid
1701        sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
1702        sess.run(
1703            constant_op.constant(1.0), options=None, run_metadata=run_metadata)
1704        self.assertTrue(not run_metadata.HasField('step_stats'))
1705
1706        sess.run(
1707            constant_op.constant(1.0), options=run_options, run_metadata=None)
1708        self.assertTrue(not run_metadata.HasField('step_stats'))
1709
1710        sess.run(
1711            constant_op.constant(1.0),
1712            options=run_options,
1713            run_metadata=run_metadata)
1714
1715        self.assertTrue(run_metadata.HasField('step_stats'))
1716        self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
1717
1718  def testFeedShapeCompatibility(self):
1719    with session.Session() as sess:
1720      some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
1721      new_shape = constant_op.constant([2, 2])
1722      reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
1723
1724      with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
1725        sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
1726
1727      with self.assertRaisesRegexp(
1728          errors.InvalidArgumentError,
1729          'Input to reshape is a tensor with 4 values, '
1730          'but the requested shape has 21'):
1731        sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
1732
1733  def testInferShapesFalse(self):
1734    with ops.Graph().as_default(), ops.device('/cpu:0'):
1735      a = constant_op.constant([[1, 2]])
1736      sess = session.Session()
1737      self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr)
1738      # Avoid lint error regarding 'unused' var a.
1739      self.assertTrue(a == a)
1740
1741  def testInferShapesTrue(self):
1742    config = config_pb2.ConfigProto(
1743        graph_options=config_pb2.GraphOptions(infer_shapes=True))
1744    with ops.Graph().as_default(), ops.device('/cpu:0'):
1745      a = constant_op.constant([[1, 2]])
1746      sess = session.Session(config=config)
1747      self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr)
1748      # Avoid lint error regarding 'unused' var a.
1749      self.assertTrue(a == a)
1750
1751  def testBuildCostModel(self):
1752    run_options = config_pb2.RunOptions()
1753    config = config_pb2.ConfigProto(
1754        allow_soft_placement=True,
1755        graph_options=config_pb2.GraphOptions(build_cost_model=100))
1756    with session.Session(config=config) as sess:
1757      with ops.device('/device:GPU:0'):
1758        a = array_ops.placeholder(dtypes.float32, shape=[])
1759        b = math_ops.add(a, a)
1760        c = array_ops.identity(b)
1761        d = math_ops.multiply(c, c)
1762      for step in xrange(120):
1763        run_metadata = config_pb2.RunMetadata()
1764        sess.run(
1765            d,
1766            feed_dict={a: 1.0},
1767            options=run_options,
1768            run_metadata=run_metadata)
1769        if step == 99:
1770          self.assertTrue(run_metadata.HasField('cost_graph'))
1771        else:
1772          self.assertFalse(run_metadata.HasField('cost_graph'))
1773
1774  def runTestOutputPartitionGraphs(self, sess):
1775    run_options = config_pb2.RunOptions(output_partition_graphs=True)
1776    a = constant_op.constant(1)
1777    run_metadata = config_pb2.RunMetadata()
1778    sess.run(a, options=run_options, run_metadata=run_metadata)
1779    self.assertGreater(len(run_metadata.partition_graphs), 0)
1780    sess.run(a, run_metadata=run_metadata)
1781    self.assertEqual(len(run_metadata.partition_graphs), 0)
1782
1783  @test_util.run_v1_only('b/120545219')
1784  def testOutputPartitionGraphsDirect(self):
1785    self.runTestOutputPartitionGraphs(session.Session())
1786
1787  @test_util.run_v1_only('b/120545219')
1788  def testOutputPartitionGraphsDistributed(self):
1789    server = server_lib.Server.create_local_server()
1790    self.runTestOutputPartitionGraphs(session.Session(server.target))
1791
1792  def testNonInteractiveSessionNesting(self):
1793    sess1 = session.Session()
1794    sess1_controller = sess1.as_default()
1795    sess1_controller.__enter__()
1796
1797    sess2 = session.Session()
1798    sess2_controller = sess2.as_default()
1799    sess2_controller.__enter__()
1800
1801    with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
1802      sess1_controller.__exit__(None, None, None)
1803
1804    ops._default_session_stack.reset()
1805
1806  def testInteractiveSessionNesting(self):
1807    sess1 = session.InteractiveSession()
1808    sess2 = session.InteractiveSession()
1809    del sess1
1810    del sess2
1811
1812  @test_util.run_v1_only('b/120545219')
1813  def testAsDefault(self):
1814    c = constant_op.constant(37)
1815    sess = session.Session()
1816    with sess.as_default():
1817      self.assertEqual(37, c.eval())
1818
1819    # Ensure that the session remains valid even when it is not captured.
1820    with session.Session().as_default():
1821      self.assertEqual(37, c.eval())
1822
1823  def testReentry(self):
1824    sess = session.Session()
1825    with self.assertRaisesRegexp(RuntimeError, 'not re-entrant'):
1826      with sess:
1827        with sess:
1828          pass
1829
1830  def testInvalidArgument(self):
1831    with self.assertRaisesRegexp(TypeError, 'target must be a string'):
1832      session.Session(37)
1833    with self.assertRaisesRegexp(TypeError, 'config must be a tf.ConfigProto'):
1834      session.Session(config=37)
1835    with self.assertRaisesRegexp(TypeError, 'graph must be a tf.Graph'):
1836      session.Session(graph=37)
1837
1838  @test_util.run_v1_only('b/120545219')
1839  def testTimeoutWithShortOperations(self):
1840    num_epochs = 5
1841    q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()])
1842    enqueue_op = q.enqueue_many(constant_op.constant([1, 2]))
1843
1844    # Use a 10-second timeout, which should be longer than any
1845    # non-blocking enqueue_many op.
1846    config = config_pb2.ConfigProto(operation_timeout_in_ms=10000)
1847    with session.Session(config=config) as sess:
1848      for _ in range(num_epochs):
1849        sess.run(enqueue_op)
1850      self.assertEqual(sess.run(q.size()), num_epochs * 2)
1851
  @test_util.run_v1_only('b/120545219')
  def testRegisterFetchAndFeedConversionFunctions(self):
    """Registered conversion functions let Session fetch/feed a custom type.

    NOTE(review): the registration persists in the `session` module (a second
    registration raises ValueError below) and is not undone by this test.
    """

    class SquaredTensor(object):
      """Test wrapper whose `sq` attribute is the square of `tensor`."""

      def __init__(self, tensor):
        self.sq = math_ops.square(tensor)

    # fetch_fn: maps a SquaredTensor to the list of tensors to fetch, plus a
    # function that turns the fetched values back into a single result.
    fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0])
    # feed_fn1: expands a (SquaredTensor, value) feed into (tensor, value)
    # pairs; exercised below via feed_dict.
    feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)]
    # feed_fn2: maps a SquaredTensor to the tensors it stands for; presumably
    # the partial-run feed hook — confirm against
    # register_session_run_conversion_functions.
    feed_fn2 = lambda feed: [feed.sq]

    session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
                                                      feed_fn1, feed_fn2)
    # Registering the same type a second time is rejected.
    with self.assertRaises(ValueError):
      session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
                                                        feed_fn1, feed_fn2)
    with self.cached_session() as sess:
      np1 = np.array([1.0, 1.5, 2.0, 2.5])
      np2 = np.array([3.0, 3.5, 4.0, 4.5])
      squared_tensor = SquaredTensor(np2)
      # Plain fetch goes through fetch_fn and yields np2 squared.
      squared_eval = sess.run(squared_tensor)
      self.assertAllClose(np2 * np2, squared_eval)
      # Feeding the wrapper overrides its `sq` tensor, so the fed value is
      # what comes back.
      squared_eval = sess.run(
          squared_tensor, feed_dict={
              squared_tensor: np1 * np1
          })
      self.assertAllClose(np1 * np1, squared_eval)
      # partial_run also resolves the wrapper through the registered fetch_fn.
      partial_run = sess.partial_run_setup([squared_tensor], [])
      squared_eval = sess.partial_run(partial_run, squared_tensor)
      self.assertAllClose(np2 * np2, squared_eval)
1883
1884  def testDefaultLogDevicePlacement(self):
1885
1886    class CaptureStderr(str):
1887      """Class to capture stderr from C++ shared library."""
1888
1889      def __enter__(self):
1890        self._esc = compat.as_str('\b')
1891        self._output = compat.as_str('')
1892        self._stderr = sys.stderr
1893        self._fd = self._stderr.fileno()
1894        self._out_pipe, in_pipe = os.pipe()
1895        # Save the original io stream.
1896        self._dup_fd = os.dup(self._fd)
1897        # Replace the original io stream with in pipe.
1898        os.dup2(in_pipe, self._fd)
1899        return self
1900
1901      def __exit__(self, *args):
1902        self._stderr.write(self._esc)
1903        self._stderr.flush()
1904        self.read()
1905        os.close(self._out_pipe)
1906        # Restore the original io stream.
1907        os.dup2(self._dup_fd, self._fd)
1908
1909      def read(self):
1910        while True:
1911          data = os.read(self._out_pipe, 1)
1912          if not data or compat.as_str(data) == self._esc:
1913            break
1914          self._output += compat.as_str(data)
1915
1916      def __str__(self):
1917        return self._output
1918
1919    if context.executing_eagerly():
1920      context.set_log_device_placement(True)
1921      with CaptureStderr() as log:
1922        a = constant_op.constant(1)
1923        b = constant_op.constant(2)
1924        c = a + b
1925    else:
1926      # Passing the config to the server, but not the session should still
1927      # result in logging device placement.
1928      config = config_pb2.ConfigProto(log_device_placement=True)
1929      server = server_lib.Server.create_local_server(config=config)
1930      a = constant_op.constant(1)
1931      b = constant_op.constant(2)
1932      c = a + b
1933      with session.Session(server.target) as sess:
1934        with CaptureStderr() as log:
1935          sess.run(c)
1936
1937    # Ensure that we did log device placement.
1938    self.assertTrue('/replica:0/task:0/device:CPU:0' in str(log), str(log))
1939
1940  @test_util.run_v1_only('b/120545219')
1941  def testLocalMasterSessionTimeout(self):
1942    # Test that the timeout passed in a config to the session works correctly.
1943    config = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
1944    server = server_lib.Server.create_local_server()
1945    q = data_flow_ops.FIFOQueue(1, dtypes.float32)
1946    dequeued_t = q.dequeue()
1947
1948    with session.Session(server.target, config=config) as sess:
1949      # Intentionally do not run any enqueue_ops so that dequeue will block
1950      # until operation_timeout_in_ms.
1951      with self.assertRaises(errors.DeadlineExceededError):
1952        sess.run(dequeued_t)
1953
1954  @test_util.run_v1_only('b/120545219')
1955  def testDefaultServerTimeout(self):
1956    # Test that the default server config timeout gets used when no Session
1957    # config is provided.
1958    config = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
1959    server = server_lib.Server.create_local_server(config=config)
1960    q = data_flow_ops.FIFOQueue(1, dtypes.float32)
1961    dequeued_t = q.dequeue()
1962
1963    with session.Session(server.target) as sess:
1964      # Intentionally do not run any enqueue_ops so that dequeue will block
1965      # until operation_timeout_in_ms.
1966      with self.assertRaises(errors.DeadlineExceededError):
1967        sess.run(dequeued_t)
1968
1969  def runTestBuildGraphError(self, sess):
1970    # Ensure that errors from building the graph get propagated.
1971    data = array_ops.placeholder(dtypes.float32, shape=[])
1972    # pylint: disable=protected-access
1973    enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False)
1974    enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False)
1975    # pylint: enable=protected-access
1976    res = math_ops.add(enter_1, enter_2)
1977    with self.assertRaisesOpError('has inputs from different frames'):
1978      sess.run(res, feed_dict={data: 1.0})
1979
1980  @test_util.run_v1_only('b/120545219')
1981  def testBuildGraphErrorDirect(self):
1982    self.runTestBuildGraphError(session.Session())
1983
1984  @test_util.run_v1_only('b/120545219')
1985  def testBuildGraphErrorDist(self):
1986    server = server_lib.Server.create_local_server()
1987    self.runTestBuildGraphError(session.Session(server.target))
1988
1989  def testDeviceAttributes(self):
1990    attrs = session._DeviceAttributes(
1991        '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
1992    self.assertEqual(1337, attrs.memory_limit_bytes)
1993    self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
1994    self.assertEqual('TYPE', attrs.device_type)
1995    self.assertEqual(1000000, attrs.incarnation)
1996    str_repr = '%s' % attrs
1997    self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
1998
1999  def testDeviceAttributesCanonicalization(self):
2000    attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
2001                                      'TYPE', 1337, 1000000)
2002    self.assertEqual(1337, attrs.memory_limit_bytes)
2003    self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
2004    self.assertEqual('TYPE', attrs.device_type)
2005    self.assertEqual(1000000, attrs.incarnation)
2006    str_repr = '%s' % attrs
2007    self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
2008
2009  def runTestAddFunctionToSession(self, target=''):
2010    """Add a function to a session after the graph has already been run."""
2011
2012    @function.Defun(dtypes.float32)
2013    def foo(x):
2014      return x + 1
2015
2016    x = constant_op.constant(1.0)
2017    with session.Session(target=target) as sess:
2018      sess.run(x)
2019      f = foo(x)
2020      result = sess.run(f)
2021      self.assertEqual(result, 2.0)
2022
2023  @test_util.run_v1_only('b/120545219')
2024  def testAddFunctionToSession(self):
2025    self.runTestAddFunctionToSession()
2026
2027  @test_util.run_v1_only('b/120545219')
2028  def testAddFunctionToGrpcSession(self):
2029    server = server_lib.Server.create_local_server()
2030    self.runTestAddFunctionToSession(server.target)
2031
2032  def testOpenAndCloseGrpcSession(self):
2033    server = server_lib.Server.create_local_server()
2034    with session.Session(server.target):
2035      pass
2036
2037  def testOpenAndCloseSession(self):
2038    with session.Session():
2039      pass
2040
2041  @test_util.run_v1_only('b/120545219')
2042  def testAutoConvertAndCheckData(self):
2043    with self.cached_session() as sess:
2044      a = array_ops.placeholder(dtype=dtypes.string)
2045      with self.assertRaisesRegexp(
2046          TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'):
2047        sess.run(a, feed_dict={a: 1})
2048
2049
if __name__ == '__main__':
  # When executed as a script, hand control to googletest's test runner.
  googletest.main()
2052