# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph actions tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil
import tempfile

from tensorflow.contrib import testing
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib


class _Feeder(object):
  """Simple generator for `feed_fn`, returning 10 * step."""

  def __init__(self, tensor, max_step):
    self._step = 0
    self._tensor = tensor
    self._max_step = max_step

  @property
  def step(self):
    return self._step

  def feed_fn(self):
    if self._step >= self._max_step:
      raise StopIteration
    value = self._step * 10.0
    self._step += 1
    return {self._tensor: value}


class _BaseMonitorWrapper(BaseMonitor):
  """Base monitor wrapper to facilitate testing.

  This monitor can act as either chief-exclusive or non-exclusive.
66 """ 67 68 def __init__(self, run_on_all_workers): 69 super(_BaseMonitorWrapper, self).__init__() 70 self._run_on_all_workers = run_on_all_workers 71 self._is_active = False 72 self._has_step = False 73 74 @property 75 def run_on_all_workers(self): 76 return self._run_on_all_workers 77 78 @property 79 def is_active(self): 80 return self._is_active 81 82 @property 83 def has_step(self): 84 return self._has_step 85 86 def begin(self, max_steps=None): 87 self._is_active = True 88 return super(_BaseMonitorWrapper, self).begin(max_steps) 89 90 def step_begin(self, step): 91 self._has_step = True 92 return super(_BaseMonitorWrapper, self).step_begin(step) 93 94 95class GraphActionsTest(test.TestCase): 96 """Graph actions tests.""" 97 98 def setUp(self): 99 learn.graph_actions.clear_summary_writers() 100 self._output_dir = tempfile.mkdtemp() 101 testing.FakeSummaryWriter.install() 102 103 def tearDown(self): 104 testing.FakeSummaryWriter.uninstall() 105 if self._output_dir: 106 shutil.rmtree(self._output_dir) 107 learn.graph_actions.clear_summary_writers() 108 109 def _assert_summaries(self, 110 output_dir, 111 writer, 112 expected_summaries=None, 113 expected_graphs=None, 114 expected_meta_graphs=None, 115 expected_session_logs=None): 116 self.assertTrue(isinstance(writer, testing.FakeSummaryWriter)) 117 writer.assert_summaries( 118 self, 119 expected_logdir=output_dir, 120 expected_graph=ops.get_default_graph(), 121 expected_summaries=expected_summaries, 122 expected_added_graphs=expected_graphs, 123 expected_added_meta_graphs=expected_meta_graphs, 124 expected_session_logs=expected_session_logs) 125 126 # TODO(ptucker): Test number and contents of checkpoint files. 127 def _assert_ckpt(self, output_dir, expected=True): 128 ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) 129 if expected: 130 pattern = '%s/model.ckpt-.*' % output_dir 131 primary_ckpt_path = ckpt_state.model_checkpoint_path 132 self.assertRegexpMatches(primary_ckpt_path, pattern) 133 all_ckpt_paths = ckpt_state.all_model_checkpoint_paths 134 self.assertTrue(primary_ckpt_path in all_ckpt_paths) 135 for ckpt_path in all_ckpt_paths: 136 self.assertRegexpMatches(ckpt_path, pattern) 137 else: 138 self.assertTrue(ckpt_state is None) 139 140 # TODO(ptucker): Test lock, multi-threaded access? 141 def test_summary_writer(self): 142 writer = learn.graph_actions.get_summary_writer('log/dir/0') 143 self._assert_summaries('log/dir/0', writer) 144 self.assertTrue( 145 learn.graph_actions.get_summary_writer('log/dir/0') is 146 learn.graph_actions.get_summary_writer('log/dir/0')) 147 self.assertTrue( 148 learn.graph_actions.get_summary_writer('log/dir/0') is 149 not learn.graph_actions.get_summary_writer('log/dir/1')) 150 151 # TODO(ptucker): Test restore_checkpoint_path for eval; this should obsolete 152 # test_evaluate_with_saver(). 153 # TODO(ptucker): Test start_queue_runners for both eval & train. 154 # TODO(ptucker): Test coord.request_stop & coord.join for eval. 155 156 def _build_inference_graph(self): 157 """Build simple inference graph. 158 159 This includes a regular variable, local variable, and fake table. 160 161 Returns: 162 Tuple of 3 `Tensor` objects, 2 input and 1 output. 
163 """ 164 variables_lib.create_global_step() 165 in0 = variables.VariableV1(1.0) 166 in1 = variables_lib.local_variable(2.0) 167 fake_table = variables.VariableV1( 168 3.0, 169 trainable=False, 170 collections=['fake_tables'], 171 name='fake_table_var') 172 in0.graph.add_to_collections([ops.GraphKeys.TABLE_INITIALIZERS], 173 fake_table.initializer) 174 out = in0 + in1 + fake_table 175 return in0, in1, out 176 177 def test_infer(self): 178 with ops.Graph().as_default() as g, self.session(g): 179 self._assert_ckpt(self._output_dir, False) 180 in0, in1, out = self._build_inference_graph() 181 self.assertEqual({ 182 'a': 1.0, 183 'b': 2.0, 184 'c': 6.0 185 }, learn.graph_actions.infer(None, {'a': in0, 186 'b': in1, 187 'c': out})) 188 self._assert_ckpt(self._output_dir, False) 189 190 @test.mock.patch.object( 191 learn.graph_actions.coordinator.Coordinator, 192 'request_stop', 193 side_effect=learn.graph_actions.coordinator.Coordinator.request_stop, 194 autospec=True) 195 def test_coordinator_request_stop_called(self, request_stop): 196 with ops.Graph().as_default() as g, self.session(g): 197 in0, in1, out = self._build_inference_graph() 198 learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out}) 199 self.assertTrue(request_stop.called) 200 201 @test.mock.patch.object( 202 learn.graph_actions.coordinator.Coordinator, 203 'request_stop', 204 side_effect=learn.graph_actions.coordinator.Coordinator.request_stop, 205 autospec=True) 206 def test_run_feeds_iter_cleanup_with_exceptions(self, request_stop): 207 with ops.Graph().as_default() as g, self.session(g): 208 in0, in1, out = self._build_inference_graph() 209 try: 210 for _ in learn.graph_actions.run_feeds_iter({ 211 'a': in0, 212 'b': in1, 213 'c': out 214 }, [None] * 3): 215 self.assertFalse(request_stop.called) 216 raise ValueError('Fake exception') 217 except ValueError: 218 pass 219 self.assertTrue(request_stop.called) 220 221 def test_run_feeds_iter_calls_resources_init(self): 222 with ops.Graph().as_default(): 223 in0, _, _ = self._build_inference_graph() 224 handle = test_ops.stub_resource_handle_op(container='a', shared_name='b') 225 resources.register_resource( 226 handle=handle, 227 create_op=test_ops.resource_create_op(handle), 228 is_initialized_op=test_ops.resource_initialized_op(handle)) 229 230 for _ in learn.graph_actions.run_feeds_iter( 231 { 232 'in0': in0 233 }, feed_dicts=[{}]): 234 self.assertTrue(test_ops.resource_initialized_op(handle).eval()) 235 236 def test_infer_different_default_graph(self): 237 with self.cached_session(): 238 self._assert_ckpt(self._output_dir, False) 239 with ops.Graph().as_default(): 240 in0, in1, out = self._build_inference_graph() 241 with ops.Graph().as_default(): 242 self.assertEqual({ 243 'a': 1.0, 244 'b': 2.0, 245 'c': 6.0 246 }, learn.graph_actions.infer(None, {'a': in0, 247 'b': in1, 248 'c': out})) 249 self._assert_ckpt(self._output_dir, False) 250 251 def test_infer_invalid_feed(self): 252 with ops.Graph().as_default() as g, self.session(g): 253 self._assert_ckpt(self._output_dir, False) 254 in0, _, _ = self._build_inference_graph() 255 with self.assertRaisesRegexp(TypeError, 'Can not convert a NoneType'): 256 learn.graph_actions.infer(None, {'a': in0}, feed_dict={None: 4.0}) 257 self._assert_ckpt(self._output_dir, False) 258 259 def test_infer_feed(self): 260 with ops.Graph().as_default() as g, self.session(g): 261 self._assert_ckpt(self._output_dir, False) 262 in0, _, out = self._build_inference_graph() 263 self.assertEqual( 264 { 265 'c': 9.0 266 }, 267 
          learn.graph_actions.infer(
              None, {'c': out}, feed_dict={in0: 4.0}))
      self._assert_ckpt(self._output_dir, False)

  # TODO(ptucker): Test eval for 1 epoch.

  def test_evaluate_invalid_args(self):
    with ops.Graph().as_default() as g, self.session(g):
      self._assert_ckpt(self._output_dir, False)
      with self.assertRaisesRegexp(ValueError, 'utput directory'):
        learn.graph_actions.evaluate(
            g,
            output_dir=None,
            checkpoint_path=None,
            eval_dict={'a': constant_op.constant(1.0)})
      with self.assertRaisesRegexp(ValueError, 'utput directory'):
        learn.graph_actions.evaluate(
            g,
            output_dir='',
            checkpoint_path=None,
            eval_dict={'a': constant_op.constant(1.0)})
      self._assert_ckpt(self._output_dir, False)

  def test_evaluate(self):
    with ops.Graph().as_default() as g, self.session(g):
      _, _, out = self._build_inference_graph()
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 6.0
          }},
          expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)

  def test_evaluate_ready_for_local_init(self):
    with ops.Graph().as_default() as g, self.session(g):
      variables_lib.create_global_step()
      v = variables.VariableV1(1.0)
      variables.VariableV1(
          v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
      ready_for_local_init_op = variables.report_uninitialized_variables(
          variables.global_variables())
      ops.add_to_collection(ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                            ready_for_local_init_op)
      _ = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': v},
          max_steps=1)

  def test_evaluate_feed_fn(self):
    with ops.Graph().as_default() as g, self.session(g):
      in0, _, out = self._build_inference_graph()
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)
      feeder = _Feeder(in0, 3)
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          feed_fn=feeder.feed_fn,
          max_steps=3)
      self.assertEqual(3, feeder.step)
      self.assertEqual(({'a': 25.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 25.0
          }},
          expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)

  def test_evaluate_feed_fn_with_exhaustion(self):
    with ops.Graph().as_default() as g, self.session(g):
      in0, _, out = self._build_inference_graph()
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      feeder = _Feeder(in0, 2)
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          feed_fn=feeder.feed_fn,
          max_steps=3)
      self.assertEqual(2, feeder.step)
      self.assertEqual(({'a': 15.0}, 0), results)
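      # The feeder stops after two feeds (0.0, then 10.0), so the final value
      # is 10.0 + 2.0 + 3.0 = 15.0.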
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 15.0
          }},
          expected_session_logs=[])

  def test_evaluate_with_saver(self):
    with ops.Graph().as_default() as g, self.session(g):
      _, _, out = self._build_inference_graph()
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 6.0
          }},
          expected_session_logs=[])

  # TODO(ptucker): Resume training from previous ckpt.
  # TODO(ptucker): !supervisor_is_chief
  # TODO(ptucker): Custom init op for training.
  # TODO(ptucker): Mock supervisor, and assert all interactions.


# TODO(ispir): remove following tests after deprecated train.
class GraphActionsTrainTest(test.TestCase):
  """Tests for train."""

  def setUp(self):
    learn.graph_actions.clear_summary_writers()
    self._output_dir = tempfile.mkdtemp()
    testing.FakeSummaryWriter.install()

  def tearDown(self):
    testing.FakeSummaryWriter.uninstall()
    if self._output_dir:
      shutil.rmtree(self._output_dir)
    learn.graph_actions.clear_summary_writers()

  def _assert_summaries(self,
                        output_dir,
                        expected_summaries=None,
                        expected_graphs=None,
                        expected_meta_graphs=None,
                        expected_session_logs=None):
    writer = learn.graph_actions.get_summary_writer(output_dir)
    self.assertTrue(isinstance(writer, testing.FakeSummaryWriter))
    writer.assert_summaries(
        self,
        expected_logdir=output_dir,
        expected_graph=ops.get_default_graph(),
        expected_summaries=expected_summaries,
        expected_added_graphs=expected_graphs,
        expected_added_meta_graphs=expected_meta_graphs,
        expected_session_logs=expected_session_logs)

  # TODO(ptucker): Test number and contents of checkpoint files.
  def _assert_ckpt(self, output_dir, expected=True):
    ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
    if expected:
      pattern = '%s/model.ckpt-.*' % output_dir
      primary_ckpt_path = ckpt_state.model_checkpoint_path
      self.assertRegexpMatches(primary_ckpt_path, pattern)
      all_ckpt_paths = ckpt_state.all_model_checkpoint_paths
      self.assertTrue(primary_ckpt_path in all_ckpt_paths)
      for ckpt_path in all_ckpt_paths:
        self.assertRegexpMatches(ckpt_path, pattern)
    else:
      self.assertTrue(ckpt_state is None)

  def _build_inference_graph(self):
    """Build simple inference graph.

    This includes a regular variable, local variable, and fake table.

    Returns:
      Tuple of 3 `Tensor` objects, 2 input and 1 output.
457 """ 458 variables_lib.create_global_step() 459 in0 = variables.VariableV1(1.0) 460 in1 = variables_lib.local_variable(2.0) 461 fake_table = variables.VariableV1( 462 3.0, 463 trainable=False, 464 collections=['fake_tables'], 465 name='fake_table_var') 466 in0.graph.add_to_collections([ops.GraphKeys.TABLE_INITIALIZERS], 467 fake_table.initializer) 468 out = in0 + in1 + fake_table 469 return in0, in1, out 470 471 def test_train_invalid_args(self): 472 with ops.Graph().as_default() as g, self.session(g): 473 train_op = constant_op.constant(1.0) 474 loss_op = constant_op.constant(2.0) 475 with self.assertRaisesRegexp(ValueError, 'utput directory'): 476 learn.graph_actions.train( 477 g, output_dir=None, train_op=train_op, loss_op=loss_op) 478 with self.assertRaisesRegexp(ValueError, 'utput directory'): 479 learn.graph_actions.train( 480 g, 481 output_dir='', 482 train_op=constant_op.constant(1.0), 483 loss_op=constant_op.constant(2.0)) 484 with self.assertRaisesRegexp(ValueError, 'train_op'): 485 learn.graph_actions.train( 486 g, output_dir=self._output_dir, train_op=None, loss_op=loss_op) 487 with self.assertRaisesRegexp(ValueError, 'loss_op'): 488 learn.graph_actions.train( 489 g, 490 output_dir=self._output_dir, 491 train_op=constant_op.constant(1.0), 492 loss_op=None) 493 with self.assertRaisesRegexp(ValueError, 'global_step'): 494 learn.graph_actions.train( 495 g, 496 output_dir=self._output_dir, 497 train_op=constant_op.constant(1.0), 498 loss_op=loss_op) 499 500 # TODO(ptucker): Resume training from previous ckpt. 501 # TODO(ptucker): !supervisor_is_chief 502 # TODO(ptucker): Custom init op for training. 503 # TODO(ptucker): Mock supervisor, and assert all interactions. 504 505 def test_train(self): 506 with ops.Graph().as_default() as g, self.session(g): 507 with ops.control_dependencies(self._build_inference_graph()): 508 train_op = state_ops.assign_add(variables_lib.get_global_step(), 1) 509 self._assert_summaries(self._output_dir) 510 self._assert_ckpt(self._output_dir, False) 511 loss = learn.graph_actions.train( 512 g, 513 output_dir=self._output_dir, 514 train_op=train_op, 515 loss_op=constant_op.constant(2.0), 516 steps=1) 517 # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the 518 # SaverDef, so we can't add it to the summary assertion test below. 
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(2.0, loss)
      self._assert_summaries(self._output_dir, expected_graphs=[g])
      self._assert_ckpt(self._output_dir, True)

  def test_train_steps_is_incremental(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)

    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          steps=15)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(25, step)

  def test_train_max_steps_is_not_incremental(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)

    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=15)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(15, step)

  def test_train_loss(self):
    with ops.Graph().as_default() as g, self.session(g):
      variables_lib.create_global_step()
      loss_var = variables_lib.local_variable(10.0)
      train_op = control_flow_ops.group(
          state_ops.assign_add(variables_lib.get_global_step(), 1),
          state_ops.assign_add(loss_var, -1.0))
      self._assert_summaries(self._output_dir)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_var.value(),
          steps=6)
      # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
      # SaverDef, so we can't add it to the summary assertion test below.
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(4.0, loss)
      self._assert_summaries(self._output_dir, expected_graphs=[g])
      self._assert_ckpt(self._output_dir, True)

  def test_train_summaries(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      self._assert_summaries(self._output_dir)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_op,
          steps=1)
      # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
      # SaverDef, so we can't add it to the summary assertion test below.
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(2.0, loss)
      self._assert_summaries(
          self._output_dir,
          expected_graphs=[g],
          expected_summaries={1: {
              'loss': 2.0
          }})
      self._assert_ckpt(self._output_dir, True)

  def test_train_chief_monitor(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_op,
          supervisor_is_chief=True,
          steps=1,
          monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'All monitors must have been active.')
      self.assertTrue(chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'All monitors must have a step.')

  def test_train_worker_monitor(self):
    # We need to explicitly set device due to check on non-chief workers
    # requiring all variables to have a device assigned.
    with ops.Graph().as_default() as g, g.device('/cpu:0'):
      global_step = variables_lib.create_global_step(g)
      train_op = state_ops.assign_add(global_step, 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      # Add explicit "local" init op to initialize all variables
      # as there's no chief to init here.
      init_op = variables.global_variables_initializer()
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
      # Create worker monitors where one should be active on the worker
      # and the other chief exclusive.
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      with self.session(g):
        loss = learn.graph_actions.train(
            g,
            output_dir=self._output_dir,
            global_step_tensor=global_step,
            train_op=train_op,
            loss_op=loss_op,
            supervisor_is_chief=False,
            steps=1,
            monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(not chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'Only non-chief runnable monitor must have been active.')
      self.assertTrue(not chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'Only non-chief runnable monitor must have a step.')


if __name__ == '__main__':
  test.main()