# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.client.session.Session's ClusterSpec Propagation.

These tests exercise the ClusterSpec Propagation capabilities of distributed
Sessions.
"""
import numpy as np

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

# Scheme prefix of `Server.target`; it is stripped before the target is used
# as a ClusterDef task entry.
_GRPC_PREFIX = 'grpc://'


def _make_cluster_config(job_name, *servers):
  """Builds a ConfigProto whose ClusterDef has one job spanning `servers`.

  Args:
    job_name: Name of the single job in the cluster.
    *servers: `server_lib.Server` instances; `servers[i]` becomes task `i`.

  Returns:
    A `config_pb2.ConfigProto` whose `cluster_def` maps each task index to the
    corresponding server target with its 'grpc://' prefix removed.
  """
  cluster_def = cluster_pb2.ClusterDef()
  job = cluster_def.job.add()
  job.name = job_name
  for task_index, server in enumerate(servers):
    job.tasks[task_index] = server.target[len(_GRPC_PREFIX):]
  return config_pb2.ConfigProto(cluster_def=cluster_def)


def _count_nodes_on_device(run_metadata, device, node_name):
  """Counts traced nodes named `node_name` that executed on `device`.

  Args:
    run_metadata: A `config_pb2.RunMetadata` populated by a traced run.
    device: Device name that `dev_stats.device` must equal exactly.
    node_name: Node name that `node_stats.node_name` must equal exactly.

  Returns:
    The number of matching (device, node) stats entries.
  """
  return sum(
      1 for dev_stats in run_metadata.step_stats.dev_stats
      if dev_stats.device == device
      for node_stats in dev_stats.node_stats
      if node_stats.node_name == node_name)


class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):

  def testClusterSpecPropagationSimple(self):
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server1, server2)

    const = constant_op.constant(17)
    # Run `const` through the ClusterSpec-configured session explicitly;
    # otherwise the session (and thus the propagated ClusterSpec) is never
    # exercised by this test.
    with session.Session(server1.target, config=config) as sess:
      output = sess.run(const)
    self.assertEqual(17, output)

  def testClusterSpecPropagationWorker2Placement(self):
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server1, server2)

    with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'):
      with ops.device('/cpu:0'):
        const = constant_op.constant(17)
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(const, options=run_options, run_metadata=run_metadata)
    self.assertEqual(17, output)
    # The constant must have executed exactly once, on task 1's CPU.
    self.assertEqual(
        1,
        _count_nodes_on_device(run_metadata,
                               '/job:worker/replica:0/task:1/device:CPU:0',
                               'Const'))

  def testClusterSpecPropagationWorker1Placement(self):
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server1, server2)

    with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'):
      const = constant_op.constant(17)
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(const)
    self.assertEqual(17, output)

  def testCanonicalDeviceNames(self):
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server1, server2)

    with ops.Graph().as_default() as g, ops.device(
        '/job:worker/task:1/device:CPU:0'):
      const = constant_op.constant(17)
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(const, options=run_options, run_metadata=run_metadata)
    self.assertEqual(17, output)
    self.assertEqual(
        1,
        _count_nodes_on_device(run_metadata,
                               '/job:worker/replica:0/task:1/device:CPU:0',
                               'Const'))

  def testFullDeviceNames(self):
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('renamed_worker', server1, server2)

    with ops.Graph().as_default() as g, ops.device(
        '/job:renamed_worker/replica:0/task:1/device:CPU:0'):
      const = constant_op.constant(17)
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(const, options=run_options, run_metadata=run_metadata)
    self.assertEqual(17, output)
    self.assertEqual(
        1,
        _count_nodes_on_device(
            run_metadata, '/job:renamed_worker/replica:0/task:1/device:CPU:0',
            'Const'))

  def testMultipleLocalDevices(self):
    # Note: CPU->CPU transfers have a fast-path in
    # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't
    # actually capture the motivating bug unless run on a GPU machine.
    #
    # Example error message (before bugfix -- line breaks added because lint):
    #
    # W0718 17:14:41.521534  190121 device_mgr.cc:107] Unknown device:
    #     /job:worker/replica:0/task:0/device:CPU:0 all devices:
    #     /job:local/replica:0/task:0/device:GPU:0,
    #     /job:local/replica:0/task:0/device:GPU:0,
    #     /job:local/replica:0/task:0/cpu:1, CPU:0, GPU:0,
    #     /job:local/replica:0/task:0/device:CPU:1,
    #     /job:local/replica:0/task:0/device:CPU:0, CPU:1,
    #     /job:local/replica:0/task:0/cpu:0
    server_config = config_pb2.ConfigProto(device_count={'CPU': 2})
    server1 = server_lib.Server.create_local_server(config=server_config)
    server2 = server_lib.Server.create_local_server(config=server_config)
    config = _make_cluster_config('worker', server1, server2)

    with ops.Graph().as_default() as g:
      with ops.device('/job:worker/task:1/cpu:1'):
        input1 = constant_op.constant(17, dtypes.float32)
      with ops.device('/job:worker/task:0/cpu:1'):
        input2 = constant_op.constant(3, dtypes.float32)
      with ops.device('/job:worker/task:1/cpu:0'):
        sum1 = input1 + input2

      if test.is_gpu_available():
        device_str = '/job:worker/task:0/device:GPU:0'
      else:
        device_str = '/job:worker/task:0/cpu:1'
      with ops.device(device_str):
        sum2 = input2 + input1

      with ops.device('/job:worker/task:0/cpu:0'):
        sum3 = sum1 + sum2
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(sum3)
    self.assertEqual(40, output)

  def testLegacyDeviceNames(self):
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server1, server2)

    with ops.Graph().as_default() as g, ops.device('/job:worker/task:1/cpu:0'):
      const = constant_op.constant(17)
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    with session.Session(server1.target, config=config, graph=g) as sess:
      output = sess.run(const, options=run_options, run_metadata=run_metadata)
    self.assertEqual(17, output)
    # The legacy '/cpu:0' spelling must resolve to the canonical device name.
    self.assertEqual(
        1,
        _count_nodes_on_device(run_metadata,
                               '/job:worker/replica:0/task:1/device:CPU:0',
                               'Const'))

  def testClusterSpecPropagationThreeServers2Graphs(self):
    """Boots 3 servers, creates 2 sessions, ensures appropriate operations.

    We create 2 clusterspecs:
     1. server2 as the master, server1 as a worker
     2. server2 as the master, server3 as a worker

    We ensure that variables on the workers are independent.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    config1 = _make_cluster_config('worker1', server2, server1)
    config2 = _make_cluster_config('worker2', server2, server3)

    with ops.Graph().as_default() as g1:
      with ops.device('/job:worker1/task:1'):
        var1 = variables.Variable(array_ops.zeros([2]), name='var1')
        update_op1 = state_ops.assign_add(
            var1, array_ops.ones([2]), name='var1_assign_add')
        init1 = variables.global_variables_initializer()

    with ops.Graph().as_default() as g2:
      with ops.device('/job:worker2/task:1'):
        var2 = variables.Variable(array_ops.zeros([2]), name='var2')
        update_op2 = state_ops.assign_add(
            var2, array_ops.ones([2]), name='var2_assign_add')
        init2 = variables.global_variables_initializer()

    sess1 = session.Session(server2.target, graph=g1, config=config1)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server2.target, graph=g2, config=config2)
    self.addCleanup(sess2.close)

    init1.run(session=sess1)
    init2.run(session=sess2)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    self.assertAllEqual(expected_zeros, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))

    self.assertAllEqual(expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))
    self.assertAllEqual(expected_ones, sess2.run(update_op2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))

  def testClusterSpecPropagationThreeServers(self):
    """Boots 3 servers, creates 2 sessions, ensures appropriate operations.

    We create 2 clusterspecs:
     1. server2 as the master, server1 as a worker
     2. server2 as the master, server3 as a worker

    Both sessions run the same (default) graph. We ensure that variables on
    the workers are independent.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    config1 = _make_cluster_config('worker', server2, server1)
    config2 = _make_cluster_config('worker', server2, server3)

    with ops.device('/job:worker/task:1'):
      var = variables.Variable(array_ops.zeros([2]), name='var')
      feed = array_ops.placeholder(dtypes.float32, shape=[2])
      update_op = var.assign_add(feed)

    sess1 = session.Session(server2.target, config=config1)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server2.target, config=config2)
    self.addCleanup(sess2.close)

    variables.global_variables_initializer().run(session=sess1)
    variables.global_variables_initializer().run(session=sess2)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    self.assertAllEqual(expected_zeros, sess1.run(var))
    self.assertAllEqual(expected_zeros, sess2.run(var))
    self.assertAllEqual(expected_ones,
                        sess1.run(update_op, feed_dict={feed: expected_ones}))
    self.assertAllEqual(expected_ones, sess1.run(var))
    self.assertAllEqual(expected_zeros, sess2.run(var))
    self.assertAllEqual(expected_ones,
                        sess2.run(update_op, feed_dict={feed: expected_ones}))
    self.assertAllEqual(expected_ones + expected_ones,
                        sess1.run(update_op, feed_dict={feed: expected_ones}))
    self.assertAllEqual(expected_ones, sess2.run(var))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var))

  def testClusterSpecPropagationThreeServersOneCluster(self):
    """Boots 3 servers, ensures appropriate communication across workers.

    Additionally, in this cluster, we ensure the master is not the 0-th worker.

    Note: this test only uses one session.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    # Task order is reversed so that the master (server1) is task 2.
    config = _make_cluster_config('worker', server3, server2, server1)

    # Add ops to the devices in non-linear order.

    with ops.device('/job:worker/task:1'):
      feed1 = array_ops.placeholder(dtypes.float32, shape=[2])
      const1 = constant_op.constant(2.0)
      mul1 = const1 * feed1

    with ops.device('/job:worker/task:2'):
      feed2 = array_ops.placeholder(dtypes.float32, shape=[2])
      const2 = constant_op.constant(2.0)
      mul2 = const2 * feed2

    with ops.device('/job:worker/task:0'):
      feed0 = array_ops.placeholder(dtypes.float32, shape=[2])
      const0 = constant_op.constant(2.0)
      mul0 = const0 * feed0

    sum_op = mul0 + mul1 + mul2

    ones = np.ones([2])
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()

    # Run!
    with session.Session(server1.target, config=config) as sess:
      output = sess.run(
          sum_op,
          options=run_options,
          run_metadata=run_metadata,
          feed_dict={feed1: ones,
                     feed2: ones,
                     feed0: ones})
      self.assertAllEqual(6 * ones, output)

      # Each of the three worker tasks should have run one Const* node.
      self.assertEqual(
          3,
          sum(1 for dev_stats in run_metadata.step_stats.dev_stats
              for node_stats in dev_stats.node_stats
              if '/job:worker/replica:0/task:' in dev_stats.device and
              node_stats.node_name.startswith('Const')), run_metadata)

  def testClusterSpecPropagationIsolation(self):
    """Test that two sessions using ClusterSpec propagation are isolated."""
    server = server_lib.Server.create_local_server()
    init_value = array_ops.placeholder(dtypes.int32, shape=[])
    v = variables.Variable(init_value)

    config = _make_cluster_config('worker', server)

    sess1 = session.Session(server.target, config=config)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server.target, config=config)
    self.addCleanup(sess2.close)

    # Initially, the variable is uninitialized in both sessions.
    with self.assertRaises(errors.FailedPreconditionError):
      sess1.run(v)
    with self.assertRaises(errors.FailedPreconditionError):
      sess2.run(v)

    # An update in sess1 should be visible in sess1 only.
    sess1.run(v.initializer, feed_dict={init_value: 37})
    self.assertEqual(37, sess1.run(v))
    with self.assertRaises(errors.FailedPreconditionError):
      sess2.run(v)

    # An update in sess2 should be visible in sess2 only.
    sess2.run(v.initializer, feed_dict={init_value: 86})
    self.assertEqual(37, sess1.run(v))
    self.assertEqual(86, sess2.run(v))

    # Closing sess2 has no effect on the state of sess1.
    sess2.close()
    self.assertEqual(37, sess1.run(v))

    # Subsequent sessions will not see the state of existing sessions.
    sess3 = session.Session(server.target, config=config)
    self.addCleanup(sess3.close)
    self.assertEqual(37, sess1.run(v))
    with self.assertRaises(errors.FailedPreconditionError):
      sess3.run(v)

  def testClusterSpecPropagationNonIsolation(self):
    """Test that two sessions using ClusterSpec propagation share state.

    For example, updated Variable values are visible among all worker
    sessions registered in the same server.
    """
    server = server_lib.Server.create_local_server()
    init_value = array_ops.placeholder(dtypes.int32, shape=[])
    v = variables.Variable(init_value)

    config = _make_cluster_config('worker', server)
    config.experimental.share_session_state_in_clusterspec_propagation = True

    sess1 = session.Session(server.target, config=config)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server.target, config=config)
    self.addCleanup(sess2.close)

    # Initially, the variable is uninitialized in both sessions.
    with self.assertRaises(errors.FailedPreconditionError):
      sess1.run(v)
    with self.assertRaises(errors.FailedPreconditionError):
      sess2.run(v)

    # An update in sess1 should be visible in sess2.
    sess1.run(v.initializer, feed_dict={init_value: 37})
    self.assertEqual(37, sess1.run(v))
    self.assertEqual(37, sess2.run(v))

    # Closing sess2 has no effect on the state of sess1.
    sess2.close()
    self.assertEqual(37, sess1.run(v))

    # Subsequent sessions should see the state of existing sessions.
    sess3 = session.Session(server.target, config=config)
    self.addCleanup(sess3.close)
    self.assertEqual(37, sess1.run(v))
    self.assertEqual(37, sess3.run(v))

  def testClusterSpecPropagationNonIsolation2Graphs(self):
    """Creates 2 sessions, each with its own graph, ensures appropriate operations.

    We ensure that variables on the workers share state.
    """
    server = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server)
    config.experimental.share_session_state_in_clusterspec_propagation = True

    with ops.Graph().as_default() as g1:
      var1 = variables.Variable(array_ops.zeros([2]), name='var')
      update_op1 = state_ops.assign_add(
          var1, array_ops.ones([2]), name='var1_assign_add')
      init1 = variables.global_variables_initializer()

    with ops.Graph().as_default() as g2:
      # Same variable name as in g1, so with shared state only g1's
      # initializer is needed.
      var2 = variables.Variable(array_ops.zeros([2]), name='var')
      update_op2 = state_ops.assign_add(
          var2, array_ops.ones([2]), name='var2_assign_add')

    sess1 = session.Session(server.target, graph=g1, config=config)
    self.addCleanup(sess1.close)
    sess2 = session.Session(server.target, graph=g2, config=config)
    self.addCleanup(sess2.close)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    init1.run(session=sess1)
    self.assertAllEqual(expected_zeros, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))

    self.assertAllEqual(expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess1.run(var1))
    self.assertAllEqual(expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess2.run(update_op2))
    self.assertAllEqual(expected_ones + expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))

  def testClusterSpecPropagationPartialRun(self):
    """Test successful partial run with ClusterSpec propagation."""
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    config = _make_cluster_config('worker', server1, server2)

    with ops.device('/job:worker/task:0'):
      a = array_ops.placeholder(dtypes.float32, shape=[])
    with ops.device('/job:worker/task:1'):
      b = array_ops.placeholder(dtypes.float32, shape=[])
      c = array_ops.placeholder(dtypes.float32, shape=[])
      r1 = math_ops.add(a, b)
    with ops.device('/job:worker/task:0'):
      r2 = math_ops.multiply(r1, c)

    with session.Session(server1.target, config=config) as sess:
      h = sess.partial_run_setup([r1, r2], [a, b, c])
      res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
      self.assertEqual(3, res)
      res = sess.partial_run(h, r2, feed_dict={c: 3})
      self.assertEqual(9, res)


if __name__ == '__main__':
  googletest.main()