• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Graph actions tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
import re
import shutil
import tempfile

from tensorflow.contrib import testing
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
40
41
class _Feeder(object):
  """Deterministic `feed_fn` source: call number k feeds the value `10.0 * k`.

  Once `max_step` feeds have been produced, every further call raises
  `StopIteration` to signal that the input is exhausted.
  """

  def __init__(self, tensor, max_step):
    self._tensor = tensor
    self._max_step = max_step
    self._step = 0

  @property
  def step(self):
    """Number of feed dicts successfully produced so far."""
    return self._step

  def feed_fn(self):
    """Returns the next `{tensor: value}` feed, or raises `StopIteration`."""
    current = self._step
    if not current < self._max_step:
      raise StopIteration
    self._step = current + 1
    return {self._tensor: 10.0 * current}
60
61
class _BaseMonitorWrapper(BaseMonitor):
  """Test monitor that records whether `begin` and `step_begin` were called.

  The `run_on_all_workers` flag given at construction selects whether the
  monitor reports itself as chief-exclusive (False) or non-exclusive (True).
  """

  def __init__(self, run_on_all_workers):
    super(_BaseMonitorWrapper, self).__init__()
    self._run_on_all_workers = run_on_all_workers
    # Both flags start cleared and are flipped by begin()/step_begin().
    self._is_active = self._has_step = False

  @property
  def run_on_all_workers(self):
    """The flag supplied to the constructor."""
    return self._run_on_all_workers

  @property
  def is_active(self):
    """True once `begin` has been invoked."""
    return self._is_active

  @property
  def has_step(self):
    """True once `step_begin` has been invoked."""
    return self._has_step

  def begin(self, max_steps=None):
    """Marks the monitor active, then delegates to the base implementation."""
    self._is_active = True
    parent_result = super(_BaseMonitorWrapper, self).begin(max_steps)
    return parent_result

  def step_begin(self, step):
    """Records that a step began, then delegates to the base implementation."""
    self._has_step = True
    parent_result = super(_BaseMonitorWrapper, self).step_begin(step)
    return parent_result
93
94
class GraphActionsTest(test.TestCase):
  """Tests for the inference and evaluation entry points of `graph_actions`."""

  def setUp(self):
    learn.graph_actions.clear_summary_writers()
    self._output_dir = tempfile.mkdtemp()
    testing.FakeSummaryWriter.install()

  def tearDown(self):
    testing.FakeSummaryWriter.uninstall()
    if self._output_dir:
      shutil.rmtree(self._output_dir)
    learn.graph_actions.clear_summary_writers()

  def _assert_summaries(self,
                        output_dir,
                        writer,
                        expected_summaries=None,
                        expected_graphs=None,
                        expected_meta_graphs=None,
                        expected_session_logs=None):
    """Asserts `writer` is a fake writer that recorded the expected events.

    Args:
      output_dir: Log directory the writer is expected to be bound to.
      writer: Summary writer under test; must be a `FakeSummaryWriter`.
      expected_summaries: Forwarded to `FakeSummaryWriter.assert_summaries`.
      expected_graphs: Forwarded as `expected_added_graphs`.
      expected_meta_graphs: Forwarded as `expected_added_meta_graphs`.
      expected_session_logs: Forwarded as `expected_session_logs`.
    """
    self.assertIsInstance(writer, testing.FakeSummaryWriter)
    writer.assert_summaries(
        self,
        expected_logdir=output_dir,
        expected_graph=ops.get_default_graph(),
        expected_summaries=expected_summaries,
        expected_added_graphs=expected_graphs,
        expected_added_meta_graphs=expected_meta_graphs,
        expected_session_logs=expected_session_logs)

  # TODO(ptucker): Test number and contents of checkpoint files.
  def _assert_ckpt(self, output_dir, expected=True):
    """Asserts whether `output_dir` contains checkpoint state.

    Args:
      output_dir: Directory inspected with
        `checkpoint_management.get_checkpoint_state`.
      expected: If True, require checkpoint state whose primary path is among
        its listed paths and every path matches `<output_dir>/model.ckpt-*`.
        If False, require that no checkpoint state is found.
    """
    ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
    if not expected:
      self.assertIsNone(ckpt_state)
      return
    # Fail with a readable assertion rather than an AttributeError on None.
    self.assertIsNotNone(ckpt_state)
    # Escape the directory so regex metacharacters in it match literally.
    pattern = '%s/model.ckpt-.*' % re.escape(output_dir)
    primary_ckpt_path = ckpt_state.model_checkpoint_path
    self.assertRegexpMatches(primary_ckpt_path, pattern)
    all_ckpt_paths = ckpt_state.all_model_checkpoint_paths
    self.assertIn(primary_ckpt_path, all_ckpt_paths)
    for ckpt_path in all_ckpt_paths:
      self.assertRegexpMatches(ckpt_path, pattern)

  # TODO(ptucker): Test lock, multi-threaded access?
  def test_summary_writer(self):
    get_writer = learn.graph_actions.get_summary_writer
    writer = get_writer('log/dir/0')
    self._assert_summaries('log/dir/0', writer)
    # The same log dir must yield the same writer; a different dir must not.
    self.assertIs(get_writer('log/dir/0'), get_writer('log/dir/0'))
    self.assertIsNot(get_writer('log/dir/0'), get_writer('log/dir/1'))

  # TODO(ptucker): Test restore_checkpoint_path for eval; this should obsolete
  # test_evaluate_with_saver().
  # TODO(ptucker): Test start_queue_runners for both eval & train.
  # TODO(ptucker): Test coord.request_stop & coord.join for eval.

  def _build_inference_graph(self):
    """Builds a simple inference graph in the current default graph.

    Creates the global step, a regular variable `in0` (1.0), a local variable
    `in1` (2.0), and a fake table variable (3.0). The fake table variable is
    kept out of the default variable collections and its initializer is
    registered under `TABLE_INITIALIZERS` instead.

    Returns:
      Tuple `(in0, in1, out)`: the two input variables and the output tensor
      `in0 + in1 + fake_table`.
    """
    variables_lib.create_global_step()
    in0 = variables.VariableV1(1.0)
    in1 = variables_lib.local_variable(2.0)
    fake_table = variables.VariableV1(
        3.0,
        trainable=False,
        collections=['fake_tables'],
        name='fake_table_var')
    in0.graph.add_to_collections([ops.GraphKeys.TABLE_INITIALIZERS],
                                 fake_table.initializer)
    out = in0 + in1 + fake_table
    return in0, in1, out

  def test_infer(self):
    with ops.Graph().as_default() as g, self.session(g):
      self._assert_ckpt(self._output_dir, False)
      in0, in1, out = self._build_inference_graph()
      self.assertEqual({
          'a': 1.0,
          'b': 2.0,
          'c': 6.0
      }, learn.graph_actions.infer(None, {'a': in0,
                                          'b': in1,
                                          'c': out}))
      self._assert_ckpt(self._output_dir, False)

  @test.mock.patch.object(
      learn.graph_actions.coordinator.Coordinator,
      'request_stop',
      side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
      autospec=True)
  def test_coordinator_request_stop_called(self, request_stop):
    with ops.Graph().as_default() as g, self.session(g):
      in0, in1, out = self._build_inference_graph()
      learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out})
      self.assertTrue(request_stop.called)

  @test.mock.patch.object(
      learn.graph_actions.coordinator.Coordinator,
      'request_stop',
      side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
      autospec=True)
  def test_run_feeds_iter_cleanup_with_exceptions(self, request_stop):
    with ops.Graph().as_default() as g, self.session(g):
      in0, in1, out = self._build_inference_graph()
      # Requiring the exception also proves the loop body actually ran.
      with self.assertRaisesRegexp(ValueError, 'Fake exception'):
        for _ in learn.graph_actions.run_feeds_iter({
            'a': in0,
            'b': in1,
            'c': out
        }, [None] * 3):
          self.assertFalse(request_stop.called)
          raise ValueError('Fake exception')
      self.assertTrue(request_stop.called)

  def test_run_feeds_iter_calls_resources_init(self):
    with ops.Graph().as_default():
      in0, _, _ = self._build_inference_graph()
      handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
      resources.register_resource(
          handle=handle,
          create_op=test_ops.resource_create_op(handle),
          is_initialized_op=test_ops.resource_initialized_op(handle))

      num_iterations = 0
      for _ in learn.graph_actions.run_feeds_iter(
          {
              'in0': in0
          }, feed_dicts=[{}]):
        self.assertTrue(test_ops.resource_initialized_op(handle).eval())
        num_iterations += 1
      # Guard against a vacuous pass if the iterator yields nothing.
      self.assertEqual(1, num_iterations)

  def test_infer_different_default_graph(self):
    with self.cached_session():
      self._assert_ckpt(self._output_dir, False)
      with ops.Graph().as_default():
        in0, in1, out = self._build_inference_graph()
      with ops.Graph().as_default():
        self.assertEqual({
            'a': 1.0,
            'b': 2.0,
            'c': 6.0
        }, learn.graph_actions.infer(None, {'a': in0,
                                            'b': in1,
                                            'c': out}))
      self._assert_ckpt(self._output_dir, False)

  def test_infer_invalid_feed(self):
    with ops.Graph().as_default() as g, self.session(g):
      self._assert_ckpt(self._output_dir, False)
      in0, _, _ = self._build_inference_graph()
      with self.assertRaisesRegexp(TypeError, 'Can not convert a NoneType'):
        learn.graph_actions.infer(None, {'a': in0}, feed_dict={None: 4.0})
      self._assert_ckpt(self._output_dir, False)

  def test_infer_feed(self):
    with ops.Graph().as_default() as g, self.session(g):
      self._assert_ckpt(self._output_dir, False)
      in0, _, out = self._build_inference_graph()
      self.assertEqual(
          {
              'c': 9.0
          },
          learn.graph_actions.infer(
              None, {'c': out}, feed_dict={in0: 4.0}))
      self._assert_ckpt(self._output_dir, False)

  # TODO(ptucker): Test eval for 1 epoch.

  def test_evaluate_invalid_args(self):
    with ops.Graph().as_default() as g, self.session(g):
      self._assert_ckpt(self._output_dir, False)
      with self.assertRaisesRegexp(ValueError, 'utput directory'):
        learn.graph_actions.evaluate(
            g,
            output_dir=None,
            checkpoint_path=None,
            eval_dict={'a': constant_op.constant(1.0)})
      with self.assertRaisesRegexp(ValueError, 'utput directory'):
        learn.graph_actions.evaluate(
            g,
            output_dir='',
            checkpoint_path=None,
            eval_dict={'a': constant_op.constant(1.0)})
      self._assert_ckpt(self._output_dir, False)

  def test_evaluate(self):
    with ops.Graph().as_default() as g, self.session(g):
      _, _, out = self._build_inference_graph()
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 6.0
          }},
          expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)

  def test_evaluate_ready_for_local_init(self):
    with ops.Graph().as_default() as g, self.session(g):
      variables_lib.create_global_step()
      v = variables.VariableV1(1.0)
      variables.VariableV1(
          v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
      ready_for_local_init_op = variables.report_uninitialized_variables(
          variables.global_variables())
      ops.add_to_collection(ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                            ready_for_local_init_op)
      # Only checks that evaluation completes with the custom ready op.
      learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': v},
          max_steps=1)

  def test_evaluate_feed_fn(self):
    with ops.Graph().as_default() as g, self.session(g):
      in0, _, out = self._build_inference_graph()
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)
      feeder = _Feeder(in0, 3)
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          feed_fn=feeder.feed_fn,
          max_steps=3)
      self.assertEqual(3, feeder.step)
      self.assertEqual(({'a': 25.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 25.0
          }},
          expected_session_logs=[])
      self._assert_ckpt(self._output_dir, False)

  def test_evaluate_feed_fn_with_exhaustion(self):
    with ops.Graph().as_default() as g, self.session(g):
      in0, _, out = self._build_inference_graph()
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      feeder = _Feeder(in0, 2)
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          feed_fn=feeder.feed_fn,
          max_steps=3)
      self.assertEqual(2, feeder.step)
      self.assertEqual(({'a': 15.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 15.0
          }},
          expected_session_logs=[])

  def test_evaluate_with_saver(self):
    with ops.Graph().as_default() as g, self.session(g):
      _, _, out = self._build_inference_graph()
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 6.0
          }},
          expected_session_logs=[])

  # TODO(ptucker): Resume training from previous ckpt.
  # TODO(ptucker): !supervisor_is_chief
  # TODO(ptucker): Custom init op for training.
  # TODO(ptucker): Mock supervisor, and assert all interactions.
402
403
# TODO(ispir): remove the following tests once the deprecated `train` is removed.
class GraphActionsTrainTest(test.TestCase):
  """Tests for the deprecated `graph_actions.train` entry point."""

  def setUp(self):
    learn.graph_actions.clear_summary_writers()
    self._output_dir = tempfile.mkdtemp()
    testing.FakeSummaryWriter.install()

  def tearDown(self):
    testing.FakeSummaryWriter.uninstall()
    if self._output_dir:
      shutil.rmtree(self._output_dir)
    learn.graph_actions.clear_summary_writers()

  def _assert_summaries(self,
                        output_dir,
                        expected_summaries=None,
                        expected_graphs=None,
                        expected_meta_graphs=None,
                        expected_session_logs=None):
    """Asserts the writer for `output_dir` recorded the expected events.

    Args:
      output_dir: Log directory whose cached summary writer is inspected.
      expected_summaries: Forwarded to `FakeSummaryWriter.assert_summaries`.
      expected_graphs: Forwarded as `expected_added_graphs`.
      expected_meta_graphs: Forwarded as `expected_added_meta_graphs`.
      expected_session_logs: Forwarded as `expected_session_logs`.
    """
    writer = learn.graph_actions.get_summary_writer(output_dir)
    self.assertIsInstance(writer, testing.FakeSummaryWriter)
    writer.assert_summaries(
        self,
        expected_logdir=output_dir,
        expected_graph=ops.get_default_graph(),
        expected_summaries=expected_summaries,
        expected_added_graphs=expected_graphs,
        expected_added_meta_graphs=expected_meta_graphs,
        expected_session_logs=expected_session_logs)

  # TODO(ptucker): Test number and contents of checkpoint files.
  def _assert_ckpt(self, output_dir, expected=True):
    """Asserts whether `output_dir` contains checkpoint state.

    Args:
      output_dir: Directory inspected with
        `checkpoint_management.get_checkpoint_state`.
      expected: If True, require checkpoint state whose primary path is among
        its listed paths and every path matches `<output_dir>/model.ckpt-*`.
        If False, require that no checkpoint state is found.
    """
    ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
    if not expected:
      self.assertIsNone(ckpt_state)
      return
    # Fail with a readable assertion rather than an AttributeError on None.
    self.assertIsNotNone(ckpt_state)
    # Escape the directory so regex metacharacters in it match literally.
    pattern = '%s/model.ckpt-.*' % re.escape(output_dir)
    primary_ckpt_path = ckpt_state.model_checkpoint_path
    self.assertRegexpMatches(primary_ckpt_path, pattern)
    all_ckpt_paths = ckpt_state.all_model_checkpoint_paths
    self.assertIn(primary_ckpt_path, all_ckpt_paths)
    for ckpt_path in all_ckpt_paths:
      self.assertRegexpMatches(ckpt_path, pattern)

  def _build_inference_graph(self):
    """Builds a simple inference graph in the current default graph.

    Creates the global step, a regular variable `in0` (1.0), a local variable
    `in1` (2.0), and a fake table variable (3.0). The fake table variable is
    kept out of the default variable collections and its initializer is
    registered under `TABLE_INITIALIZERS` instead.

    Returns:
      Tuple `(in0, in1, out)`: the two input variables and the output tensor
      `in0 + in1 + fake_table`.
    """
    variables_lib.create_global_step()
    in0 = variables.VariableV1(1.0)
    in1 = variables_lib.local_variable(2.0)
    fake_table = variables.VariableV1(
        3.0,
        trainable=False,
        collections=['fake_tables'],
        name='fake_table_var')
    in0.graph.add_to_collections([ops.GraphKeys.TABLE_INITIALIZERS],
                                 fake_table.initializer)
    out = in0 + in1 + fake_table
    return in0, in1, out

  def test_train_invalid_args(self):
    with ops.Graph().as_default() as g, self.session(g):
      train_op = constant_op.constant(1.0)
      loss_op = constant_op.constant(2.0)
      with self.assertRaisesRegexp(ValueError, 'utput directory'):
        learn.graph_actions.train(
            g, output_dir=None, train_op=train_op, loss_op=loss_op)
      with self.assertRaisesRegexp(ValueError, 'utput directory'):
        learn.graph_actions.train(
            g, output_dir='', train_op=train_op, loss_op=loss_op)
      with self.assertRaisesRegexp(ValueError, 'train_op'):
        learn.graph_actions.train(
            g, output_dir=self._output_dir, train_op=None, loss_op=loss_op)
      with self.assertRaisesRegexp(ValueError, 'loss_op'):
        learn.graph_actions.train(
            g, output_dir=self._output_dir, train_op=train_op, loss_op=None)
      # No global step has been created in this graph.
      with self.assertRaisesRegexp(ValueError, 'global_step'):
        learn.graph_actions.train(
            g, output_dir=self._output_dir, train_op=train_op, loss_op=loss_op)

  # TODO(ptucker): Resume training from previous ckpt.
  # TODO(ptucker): !supervisor_is_chief
  # TODO(ptucker): Custom init op for training.
  # TODO(ptucker): Mock supervisor, and assert all interactions.

  def test_train(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      self._assert_summaries(self._output_dir)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          steps=1)
      # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
      # SaverDef, so we can't add it to the summary assertion test below.
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(2.0, loss)
      self._assert_summaries(self._output_dir, expected_graphs=[g])
      self._assert_ckpt(self._output_dir, True)

  def test_train_steps_is_incremental(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)

    # A second run resumes from the checkpoint and trains `steps` more steps.
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          steps=15)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(25, step)

  def test_train_max_steps_is_not_incremental(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)

    # A second run stops once the global step reaches `max_steps`.
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=15)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(15, step)

  def test_train_loss(self):
    with ops.Graph().as_default() as g, self.session(g):
      variables_lib.create_global_step()
      loss_var = variables_lib.local_variable(10.0)
      train_op = control_flow_ops.group(
          state_ops.assign_add(variables_lib.get_global_step(), 1),
          state_ops.assign_add(loss_var, -1.0))
      self._assert_summaries(self._output_dir)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_var.value(),
          steps=6)
      # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
      # SaverDef, so we can't add it to the summary assertion test below.
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(4.0, loss)
      self._assert_summaries(self._output_dir, expected_graphs=[g])
      self._assert_ckpt(self._output_dir, True)

  def test_train_summaries(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      self._assert_summaries(self._output_dir)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_op,
          steps=1)
      # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
      # SaverDef, so we can't add it to the summary assertion test below.
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(2.0, loss)
      self._assert_summaries(
          self._output_dir,
          expected_graphs=[g],
          expected_summaries={1: {
              'loss': 2.0
          }})
      self._assert_ckpt(self._output_dir, True)

  def test_train_chief_monitor(self):
    with ops.Graph().as_default() as g, self.session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_op,
          supervisor_is_chief=True,
          steps=1,
          monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      # On the chief, both kinds of monitor must run. Check each monitor
      # separately so a failure identifies the offending one.
      for monitor in (chief_exclusive_monitor, all_workers_monitor):
        self.assertTrue(monitor.is_active,
                        'All monitors must have been active.')
        self.assertTrue(monitor.has_step, 'All monitors must have a step.')

  def test_train_worker_monitor(self):
    # We need to explicitly set device due to check on non-chief workers
    # requiring all variables to have a device assigned.
    with ops.Graph().as_default() as g, g.device('/cpu:0'):
      global_step = variables_lib.create_global_step(g)
      train_op = state_ops.assign_add(global_step, 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      # Add explicit "local" init op to initialize all variables
      # as there's no chief to init here.
      init_op = variables.global_variables_initializer()
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
      # Create worker monitors where one should be active on the worker
      # and the other chief exclusive.
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      with self.session(g):
        loss = learn.graph_actions.train(
            g,
            output_dir=self._output_dir,
            global_step_tensor=global_step,
            train_op=train_op,
            loss_op=loss_op,
            supervisor_is_chief=False,
            steps=1,
            monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      active_msg = 'Only non-chief runnable monitor must have been active.'
      step_msg = 'Only non-chief runnable monitor must have a step.'
      self.assertFalse(chief_exclusive_monitor.is_active, active_msg)
      self.assertTrue(all_workers_monitor.is_active, active_msg)
      self.assertFalse(chief_exclusive_monitor.has_step, step_msg)
      self.assertTrue(all_workers_monitor.has_step, step_msg)
683
684
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
687