• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.client.session.Session's ClusterSpec Propagation.
16
17These tests exercise the ClusterSpec Propagation capabilities of distributed
18Sessions.
19"""
20import numpy as np
21
22from tensorflow.core.protobuf import cluster_pb2
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.client import session
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import math_ops
32# Import resource_variable_ops for the variables-to-tensor implicit conversion.
33from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
34from tensorflow.python.ops import state_ops
35from tensorflow.python.ops import variables
36from tensorflow.python.platform import googletest
37from tensorflow.python.platform import test
38from tensorflow.python.training import server_lib
39
40
class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
  """Exercises ClusterSpec propagation for distributed `Session`s.

  Each test starts in-process servers with `Server.create_local_server()`,
  describes them in a `ClusterDef` passed to the session through
  `ConfigProto.cluster_def`, and checks results and (where traced) the device
  on which ops actually executed.
  """

  def _make_cluster_config(self, *servers, job_name='worker'):
    """Builds a session config whose ClusterSpec contains a single job.

    Args:
      *servers: Servers to register, in task order: task `i` of the job is
        bound to the address of `servers[i]`.
      job_name: Name of the job in the ClusterDef.

    Returns:
      A `config_pb2.ConfigProto` with `cluster_def` populated.
    """
    cluster_def = cluster_pb2.ClusterDef()
    job = cluster_def.job.add()
    job.name = job_name
    for task_index, server in enumerate(servers):
      # Server targets carry a 'grpc://' scheme prefix that the ClusterDef
      # task address must not include.
      job.tasks[task_index] = server.target[len('grpc://'):]
    return config_pb2.ConfigProto(cluster_def=cluster_def)

  def _run_traced(self, sess, fetches, feed_dict=None):
    """Runs `fetches` in `sess` with FULL_TRACE enabled.

    Returns:
      A `(output, run_metadata)` tuple; `run_metadata.step_stats` holds the
      per-device execution trace.
    """
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    output = sess.run(
        fetches,
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=run_metadata)
    return output, run_metadata

  def _count_nodes_on_device(self, run_metadata, device, node_name):
    """Returns how many traced node executions of `node_name` ran on `device`."""
    return sum(
        1
        for dev_stats in run_metadata.step_stats.dev_stats
        for node_stats in dev_stats.node_stats
        if device == dev_stats.device and node_name == node_stats.node_name)

  def testClusterSpecPropagationSimple(self):
    """A constant can be fetched through a ClusterSpec-configured session."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server1, server2)

    const = constant_op.constant(17)
    # Fetch through `sess` itself: a bare `self.evaluate` would not use this
    # session (it is not installed as the default), so the propagated
    # ClusterSpec would never be exercised.
    with session.Session(server1.target, config=config) as sess:
      output = sess.run(const)
    self.assertEqual(17, output)

  def testClusterSpecPropagationWorker2Placement(self):
    """An op pinned to /job:worker/task:1 executes on that task's CPU:0."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server1, server2)

    with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'):
      with ops.device('/cpu:0'):
        const = constant_op.constant(17)
    with session.Session(server1.target, config=config, graph=g) as sess:
      output, run_metadata = self._run_traced(sess, const)
    self.assertEqual(17, output)
    self.assertEqual(
        1,
        self._count_nodes_on_device(
            run_metadata, '/job:worker/replica:0/task:1/device:CPU:0',
            'Const'))

  def testClusterSpecPropagationWorker1Placement(self):
    """An op pinned to task 0 (the session's own server) runs correctly."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server1, server2)

    with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
      const = constant_op.constant(17)
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(const)
    self.assertEqual(17, output)

  def testCanonicalDeviceNames(self):
    """A '/device:CPU:0' spec resolves to the full canonical device name."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server1, server2)

    with ops.Graph().as_default() as g, ops.device(
        '/job:worker/task:1/device:CPU:0'):
      const = constant_op.constant(17)
    with session.Session(server1.target, config=config, graph=g) as sess:
      output, run_metadata = self._run_traced(sess, const)
    self.assertEqual(17, output)
    self.assertEqual(
        1,
        self._count_nodes_on_device(
            run_metadata, '/job:worker/replica:0/task:1/device:CPU:0',
            'Const'))

  def testFullDeviceNames(self):
    """A fully qualified device name under a custom job name is honored."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(
        server1, server2, job_name='renamed_worker')

    with ops.Graph().as_default() as g, ops.device(
        '/job:renamed_worker/replica:0/task:1/device:CPU:0'):
      const = constant_op.constant(17)
    with session.Session(server1.target, config=config, graph=g) as sess:
      output, run_metadata = self._run_traced(sess, const)
    self.assertEqual(17, output)
    self.assertEqual(
        1,
        self._count_nodes_on_device(
            run_metadata, '/job:renamed_worker/replica:0/task:1/device:CPU:0',
            'Const'))

  def testMultipleLocalDevices(self):
    """Data flows correctly between tasks that each expose two CPU devices."""
    # Note: CPU->CPU transfers have a fast-path in
    # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't
    # actually capture the motivating bug unless run on a GPU machine.
    #
    # Example error message (before bugfix -- line breaks added because of
    # lint):
    #
    # W0718 17:14:41.521534  190121 device_mgr.cc:107] Unknown device:
    #     /job:worker/replica:0/task:0/device:CPU:0 all devices:
    #     /job:local/replica:0/task:0/device:GPU:0,
    #     /job:local/replica:0/task:0/device:GPU:0,
    #     /job:local/replica:0/task:0/cpu:1, CPU:0, GPU:0,
    #     /job:local/replica:0/task:0/device:CPU:1,
    #     /job:local/replica:0/task:0/device:CPU:0, CPU:1,
    #     /job:local/replica:0/task:0/cpu:0
    server_config = config_pb2.ConfigProto(device_count={'CPU': 2})
    server1 = server_lib.Server.create_local_server(config=server_config)
    server2 = server_lib.Server.create_local_server(config=server_config)
    config = self._make_cluster_config(server1, server2)

    with ops.Graph().as_default() as g:
      with ops.device('/job:worker/task:1/cpu:1'):
        input1 = constant_op.constant(17, dtypes.float32)
      with ops.device('/job:worker/task:0/cpu:1'):
        input2 = constant_op.constant(3, dtypes.float32)
      with ops.device('/job:worker/task:1/cpu:0'):
        sum1 = input1 + input2

      if test.is_gpu_available():
        device_str = '/job:worker/task:0/device:GPU:0'
      else:
        device_str = '/job:worker/task:0/cpu:1'
      with ops.device(device_str):
        sum2 = input2 + input1

      with ops.device('/job:worker/task:0/cpu:0'):
        sum3 = sum1 + sum2
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(sum3)
    self.assertEqual(40, output)

  def testLegacyDeviceNames(self):
    """The legacy '/cpu:0' spelling is placed on '/device:CPU:0'."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server1, server2)

    with ops.Graph().as_default() as g, ops.device('/job:worker/task:1/cpu:0'):
      const = constant_op.constant(17)
    with session.Session(server1.target, config=config, graph=g) as sess:
      output, run_metadata = self._run_traced(sess, const)
    self.assertEqual(17, output)
    self.assertEqual(
        1,
        self._count_nodes_on_device(
            run_metadata, '/job:worker/replica:0/task:1/device:CPU:0',
            'Const'))

  def testClusterSpecPropagationThreeServers2Graphs(self):
    """Boots 3 servers, creates 2 sessions, ensures appropriate operations.

    We create 2 clusterspecs:
     1. server2 as the master, server1 as a worker
     2. server2 as the master, server3 as a worker

    We ensure that variables on the workers are independent.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    config1 = self._make_cluster_config(server2, server1, job_name='worker1')
    config2 = self._make_cluster_config(server2, server3, job_name='worker2')

    with ops.Graph().as_default() as g1:
      with ops.device('/job:worker1/task:1'):
        var1 = variables.Variable(array_ops.zeros([2]), name='var1')
        update_op1 = state_ops.assign_add(
            var1, array_ops.ones([2]), name='var1_assign_add')
        init1 = variables.global_variables_initializer()

    with ops.Graph().as_default() as g2:
      with ops.device('/job:worker2/task:1'):
        var2 = variables.Variable(array_ops.zeros([2]), name='var2')
        update_op2 = state_ops.assign_add(
            var2, array_ops.ones([2]), name='var2_assign_add')
        init2 = variables.global_variables_initializer()

    sess1 = session.Session(server2.target, graph=g1, config=config1)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server2.target, graph=g2, config=config2)
    self.addCleanup(sess2.close)

    init1.run(session=sess1)
    init2.run(session=sess2)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    self.assertAllEqual(expected_zeros, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))

    self.assertAllEqual(expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))
    self.assertAllEqual(expected_ones, sess2.run(update_op2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))

  def testClusterSpecPropagationThreeServers(self):
    """Boots 3 servers, creates 2 sessions, ensures appropriate operations.

    We create 2 clusterspecs:
     1. server2 as the master, server1 as a worker
     2. server2 as the master, server3 as a worker

    We ensure that variables on the workers are independent.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    config1 = self._make_cluster_config(server2, server1)
    config2 = self._make_cluster_config(server2, server3)

    with ops.device('/job:worker/task:1'):
      var = variables.Variable(array_ops.zeros([2]), name='var')
      feed = array_ops.placeholder(dtypes.float32, shape=(2,))
      update_op = var.assign_add(feed)

    sess1 = session.Session(server2.target, config=config1)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server2.target, config=config2)
    self.addCleanup(sess2.close)

    variables.global_variables_initializer().run(session=sess1)
    variables.global_variables_initializer().run(session=sess2)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    self.assertAllEqual(expected_zeros, sess1.run(var))
    self.assertAllEqual(expected_zeros, sess2.run(var))
    self.assertAllEqual(expected_ones,
                        sess1.run(update_op, feed_dict={feed: expected_ones}))
    self.assertAllEqual(expected_ones, sess1.run(var))
    self.assertAllEqual(expected_zeros, sess2.run(var))
    self.assertAllEqual(expected_ones,
                        sess2.run(update_op, feed_dict={feed: expected_ones}))
    self.assertAllEqual(expected_ones + expected_ones,
                        sess1.run(update_op, feed_dict={feed: expected_ones}))
    self.assertAllEqual(expected_ones, sess2.run(var))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var))

  def testClusterSpecPropagationThreeServersOneCluster(self):
    """Boots 3 servers, ensures appropriate communication across workers.

    Additionally, in this cluster, we ensure the master is not the 0-th worker.

    Note: this test only uses one session.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    # server1 (the session target / master) is deliberately task 2.
    config = self._make_cluster_config(server3, server2, server1)

    # Add ops to the devices in non-linear order.

    with ops.device('/job:worker/task:1'):
      feed1 = array_ops.placeholder(dtypes.float32, shape=(2,))
      const1 = constant_op.constant(2.0)
      mul1 = const1 * feed1

    with ops.device('/job:worker/task:2'):
      feed2 = array_ops.placeholder(dtypes.float32, shape=(2,))
      const2 = constant_op.constant(2.0)
      mul2 = const2 * feed2

    with ops.device('/job:worker/task:0'):
      feed0 = array_ops.placeholder(dtypes.float32, shape=(2,))
      const0 = constant_op.constant(2.0)
      mul0 = const0 * feed0

    sum_op = mul0 + mul1 + mul2

    ones = np.ones([2])

    # Run!
    with session.Session(server1.target, config=config) as sess:
      output, run_metadata = self._run_traced(
          sess, sum_op, feed_dict={feed1: ones, feed2: ones, feed0: ones})
      self.assertAllEqual(6 * ones, output)

      self.assertEqual(
          3,
          len([
              dev_stats.device
              for dev_stats in run_metadata.step_stats.dev_stats
              for node_stats in dev_stats.node_stats
              if '/job:worker/replica:0/task:' in dev_stats.device and
              node_stats.node_name.startswith('Const')
          ]), run_metadata)

  def testClusterSpecPropagationIsolation(self):
    """Test that two sessions using ClusterSpec propagation are isolated."""
    server = server_lib.Server.create_local_server()
    init_value = array_ops.placeholder(dtypes.int32, shape=[])
    v = variables.Variable(init_value)

    config = self._make_cluster_config(server)

    sess1 = session.Session(server.target, config=config)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server.target, config=config)
    # Closing an already-closed session is a no-op, so this cleanup is safe
    # even though the test closes sess2 explicitly below.
    self.addCleanup(sess2.close)

    # Initially, the variable is uninitialized in both sessions.
    with self.assertRaises(errors.FailedPreconditionError):
      sess1.run(v)
    with self.assertRaises(errors.FailedPreconditionError):
      sess2.run(v)

    # An update in sess1 should be visible in sess1 only.
    sess1.run(v.initializer, feed_dict={init_value: 37})
    self.assertEqual(37, sess1.run(v))
    with self.assertRaises(errors.FailedPreconditionError):
      sess2.run(v)

    # An update in sess2 should be visible in sess2 only.
    sess2.run(v.initializer, feed_dict={init_value: 86})
    self.assertEqual(37, sess1.run(v))
    self.assertEqual(86, sess2.run(v))

    # Closing sess2 has no effect on the state of sess1.
    sess2.close()
    self.assertEqual(37, sess1.run(v))

    # Subsequent sessions will not see the state of existing sessions.
    sess3 = session.Session(server.target, config=config)
    self.addCleanup(sess3.close)
    self.assertEqual(37, sess1.run(v))
    with self.assertRaises(errors.FailedPreconditionError):
      sess3.run(v)

  def testClusterSpecPropagationNonIsolation(self):
    """Test that two sessions using ClusterSpec propagation share state.

    For example, updated Variable values are visible to all worker sessions
    registered in the same server.
    """
    server = server_lib.Server.create_local_server()
    init_value = array_ops.placeholder(dtypes.int32, shape=[])
    v = variables.Variable(init_value)

    config = self._make_cluster_config(server)
    config.experimental.share_session_state_in_clusterspec_propagation = True

    sess1 = session.Session(server.target, config=config)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server.target, config=config)
    # Closing an already-closed session is a no-op, so this cleanup is safe
    # even though the test closes sess2 explicitly below.
    self.addCleanup(sess2.close)

    # Initially, the variable is uninitialized in both sessions.
    with self.assertRaises(errors.FailedPreconditionError):
      sess1.run(v)
    with self.assertRaises(errors.FailedPreconditionError):
      sess2.run(v)

    # An update in sess1 should be visible in sess2.
    sess1.run(v.initializer, feed_dict={init_value: 37})
    self.assertEqual(37, sess1.run(v))
    self.assertEqual(37, sess2.run(v))

    # Closing sess2 has no effect on the state of sess1.
    sess2.close()
    self.assertEqual(37, sess1.run(v))

    # Subsequent sessions should see the state of existing sessions.
    sess3 = session.Session(server.target, config=config)
    self.addCleanup(sess3.close)
    self.assertEqual(37, sess1.run(v))
    self.assertEqual(37, sess3.run(v))

  def testClusterSpecPropagationNonIsolation2Graphs(self):
    """Creates 2 sessions, each with its own graph; checks shared state.

    Both graphs declare a variable named 'var'; we ensure that the variables
    on the worker share state across the two sessions.
    """
    server = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server)
    config.experimental.share_session_state_in_clusterspec_propagation = True

    with ops.Graph().as_default() as g1:
      var1 = variables.Variable(array_ops.zeros([2]), name='var')
      update_op1 = state_ops.assign_add(
          var1, array_ops.ones([2]), name='var1_assign_add')
      init1 = variables.global_variables_initializer()

    with ops.Graph().as_default() as g2:
      var2 = variables.Variable(array_ops.zeros([2]), name='var')
      update_op2 = state_ops.assign_add(
          var2, array_ops.ones([2]), name='var2_assign_add')

    sess1 = session.Session(server.target, graph=g1, config=config)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server.target, graph=g2, config=config)
    self.addCleanup(sess2.close)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    # Only g1 is initialized; g2 sees the value through shared state.
    init1.run(session=sess1)
    self.assertAllEqual(expected_zeros, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))

    self.assertAllEqual(expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess1.run(var1))
    self.assertAllEqual(expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess2.run(update_op2))
    self.assertAllEqual(expected_ones + expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))

  def testClusterSpecPropagationPartialRun(self):
    """Test successful partial run with ClusterSpec propagation."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = self._make_cluster_config(server1, server2)

    with ops.device('/job:worker/task:0'):
      a = array_ops.placeholder(dtypes.float32, shape=[])
    with ops.device('/job:worker/task:1'):
      b = array_ops.placeholder(dtypes.float32, shape=[])
      c = array_ops.placeholder(dtypes.float32, shape=[])
      r1 = math_ops.add(a, b)
    with ops.device('/job:worker/task:0'):
      r2 = math_ops.multiply(r1, c)

    with session.Session(server1.target, config=config) as sess:
      h = sess.partial_run_setup([r1, r2], [a, b, c])
      res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
      self.assertEqual(3, res)
      res = sess.partial_run(h, r2, feed_dict={c: 3})
      self.assertEqual(9, res)
560
if '__main__' == __name__:
  # Run this module's test cases through TensorFlow's googletest entry point.
  googletest.main()
563