• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for V2 summary ops from summary_ops_v2."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import unittest
23
24import six
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import node_def_pb2
28from tensorflow.core.framework import step_stats_pb2
29from tensorflow.core.framework import summary_pb2
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.util import event_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.framework import test_util
41from tensorflow.python.lib.io import tf_record
42from tensorflow.python.module import module
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import summary_ops_v2 as summary_ops
45from tensorflow.python.ops import variables
46from tensorflow.python.platform import gfile
47from tensorflow.python.platform import test
48from tensorflow.python.platform import tf_logging as logging
49from tensorflow.python.saved_model import load as saved_model_load
50from tensorflow.python.saved_model import loader as saved_model_loader
51from tensorflow.python.saved_model import save as saved_model_save
52from tensorflow.python.saved_model import tag_constants
53
54
class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
  """Tests for the core V2 summary ops.

  Covers `write`, `write_raw_pb`, the default-step accessors (`get_step` /
  `set_step`), `summary_scope`, `record_if`, `should_record_summaries` and
  `all_v2_summary_ops`, in eager mode, inside `tf.function`, and in legacy
  graph mode.

  In every event file read back here, `events[0]` is the file_version event
  written when the writer is created; summaries start at `events[1]`.

  Tests that change the default step via `set_step` reset it to `None` in a
  `finally` clause so that later tests start without a default step.
  """

  def testWrite(self):
    """Eager `write` returns True and records the tag, step and value."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write('tag', 42, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_fromFunction(self):
    """`write` inside a tf.function records the same event as the eager call."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write('tag', 42, step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_metadata(self):
    """Metadata may be a proto, serialized bytes, or a string tensor."""
    logdir = self.get_temp_dir()
    metadata = summary_pb2.SummaryMetadata()
    metadata.plugin_data.plugin_name = 'foo'
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('obj', 0, 0, metadata=metadata)
        summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString())
        m = constant_op.constant(metadata.SerializeToString())
        summary_ops.write('string_tensor', 0, 0, metadata=m)
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(metadata, events[1].summary.value[0].metadata)
    self.assertEqual(metadata, events[2].summary.value[0].metadata)
    self.assertEqual(metadata, events[3].summary.value[0].metadata)

  def testWrite_name(self):
    """The `name` argument prefixes the name of the op returned by `write`."""
    @def_function.function
    def f():
      output = summary_ops.write('tag', 42, step=12, name='anonymous')
      self.assertTrue(output.name.startswith('anonymous'))
    f()

  def testWrite_ndarray(self):
    """A nested Python list value is written and read back unchanged."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value))

  def testWrite_tensor(self):
    """An eager tensor value round-trips through the event file."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      t = constant_op.constant([[1, 2], [3, 4]])
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', t, step=12)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_tensor_fromFunction(self):
    """A tensor passed into a tf.function round-trips through the event file."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f(t):
        with writer.as_default():
          summary_ops.write('tag', t, step=12)
      t = constant_op.constant([[1, 2], [3, 4]])
      f(t)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_stringTensor(self):
    """A list of byte strings is written and read back unchanged."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [b'foo', b'bar'], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([b'foo', b'bar'], to_numpy(value))

  @test_util.run_gpu_only
  def testWrite_gpuDeviceContext(self):
    """`write` succeeds when value and step are created under a GPU scope."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with ops.device('/GPU:0'):
          value = constant_op.constant(42.0)
          step = constant_op.constant(12, dtype=dtypes.int64)
          summary_ops.write('tag', value, step=step).numpy()
    empty_metadata = summary_pb2.SummaryMetadata()
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertEqual(42, to_numpy(events[1].summary.value[0]))
    self.assertEqual(empty_metadata, events[1].summary.value[0].metadata)

  @test_util.also_run_as_tf_function
  def testWrite_noDefaultWriter(self):
    """Without a default writer, `write` returns False."""
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42, step=0))

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfAlsoNoDefaultWriter(self):
    """A missing step is not an error when there is no default writer."""
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42))

  def testWrite_noStep(self):
    """With a default writer but no step, `write` raises ValueError."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(ValueError, 'No step set'):
          summary_ops.write('tag', 42)

  def testWrite_noStep_okayIfNotRecordingSummaries(self):
    """A missing step is not an error while record_if(False) is active."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('tag', 42))

  def testWrite_usingDefaultStep(self):
    """`write` falls back to the step from `set_step`, constant or variable."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        with summary_ops.create_file_writer_v2(logdir).as_default():
          summary_ops.set_step(1)
          summary_ops.write('tag', 1.0)
          summary_ops.set_step(2)
          summary_ops.write('tag', 1.0)
          mystep = variables.Variable(10, dtype=dtypes.int64)
          summary_ops.set_step(mystep)
          summary_ops.write('tag', 1.0)
          mystep.assign_add(1)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertEqual(5, len(events))
      self.assertEqual(1, events[1].step)
      self.assertEqual(2, events[2].step)
      self.assertEqual(10, events[3].step)
      self.assertEqual(11, events[4].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromFunction(self):
    """A constant default step is fixed when the tf.function is traced."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        summary_ops.set_step(1)
        f()
        summary_ops.set_step(2)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the function was first traced.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromFunction(self):
    """A variable default step is read on every tf.function call."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        f()
        mystep.assign_add(1)
        f()
        mystep.assign(10)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromLegacyGraph(self):
    """A constant default step is fixed when the graph op is constructed."""
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        summary_ops.set_step(1)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        summary_ops.set_step(2)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(write_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the graph was constructed.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromLegacyGraph(self):
    """A variable default step is read each time the graph op runs."""
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        first_assign_op = mystep.assign_add(1)
        second_assign_op = mystep.assign(10)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(mystep.initializer)
          sess.run(write_op)
          sess.run(first_assign_op)
          sess.run(write_op)
          sess.run(second_assign_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromAsDefault(self):
    """`as_default(step=...)` scopes the step; `as_default()` leaves it alone."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        with writer.as_default(step=1):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=2):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            summary_ops.set_step(3)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 1, 2, 1, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromAsDefault(self):
    """A variable passed to `as_default(step=...)` is read at each write."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        with writer.as_default(step=mystep):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            mystep.assign(2)
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=3):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            mystep.assign(4)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3, 2, 4], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromSetAsDefault(self):
    """`set_as_default(step=...)` sets the step; omitting it keeps the last."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        writer.set_as_default(step=1)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default()
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 2], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromSetAsDefault(self):
    """A variable passed to `set_as_default(step=...)` is read at each write."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        writer.set_as_default(step=mystep)
        summary_ops.write('tag', 1.0)
        mystep.assign(2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=3)
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_recordIf_constant(self):
    """record_if(True/False) controls whether `write` records the summary."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        self.assertTrue(summary_ops.write('default', 1, step=0))
        with summary_ops.record_if(True):
          self.assertTrue(summary_ops.write('set_on', 1, step=0))
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('set_off', 1, step=0))
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_constant_fromFunction(self):
    """Constant record_if conditions are honored inside a tf.function."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          # Use assertAllEqual instead of assertTrue since it works in a defun.
          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
          with summary_ops.record_if(True):
            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
          with summary_ops.record_if(False):
            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
      f()
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_callable(self):
    """A Python callable passed to record_if is evaluated on every `write`."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      step = variables.Variable(-1, dtype=dtypes.int64)
      # Each evaluation advances the step, so the result alternates per write.
      def record_fn():
        step.assign_add(1)
        return int(step % 2) == 0
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(record_fn):
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_callable_fromFunction(self):
    """A tf.function condition passed to record_if is evaluated per `write`."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      step = variables.Variable(-1, dtype=dtypes.int64)
      @def_function.function
      def record_fn():
        step.assign_add(1)
        return math_ops.equal(step % 2, 0)
      @def_function.function
      def f():
        with writer.as_default():
          with summary_ops.record_if(record_fn):
            return [
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step)]
      self.assertAllEqual(f(), [True, False, True])
      self.assertAllEqual(f(), [False, True, False])
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_tensorInput_fromFunction(self):
    """record_if accepts a boolean tensor computed inside a tf.function."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
      def f(step):
        with writer.as_default():
          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
            return summary_ops.write('tag', 1, step=step)
      self.assertTrue(f(0))
      self.assertFalse(f(1))
      self.assertTrue(f(2))
      self.assertFalse(f(3))
      self.assertTrue(f(4))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWriteRawPb(self):
    """`write_raw_pb` records a serialized Summary proto as-is."""
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_fromFunction(self):
    """`write_raw_pb` inside a tf.function records the proto as-is."""
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_multipleValues(self):
    """Several serialized Summary protos are merged into a single event."""
    logdir = self.get_temp_dir()
    pb1 = summary_pb2.Summary()
    pb1.value.add().simple_value = 1.0
    pb1.value.add().simple_value = 2.0
    pb2 = summary_pb2.Summary()
    pb2.value.add().simple_value = 3.0
    pb3 = summary_pb2.Summary()
    pb3.value.add().simple_value = 4.0
    pb3.value.add().simple_value = 5.0
    pb3.value.add().simple_value = 6.0
    pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)]
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pbs, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    expected_pb = summary_pb2.Summary()
    for i in range(6):
      expected_pb.value.add().simple_value = i + 1.0
    self.assertProtoEquals(expected_pb, events[1].summary)

  def testWriteRawPb_invalidValue(self):
    """A string that is not a Summary proto raises DataLossError."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(
            errors.DataLossError,
            'Bad tf.compat.v1.Summary binary proto tensor string'):
          summary_ops.write_raw_pb('notaproto', step=12)

  @test_util.also_run_as_tf_function
  def testGetSetStep(self):
    """`get_step` starts as None and returns whatever `set_step` stored."""
    try:
      self.assertIsNone(summary_ops.get_step())
      summary_ops.set_step(1)
      # Use assertAllEqual instead of assertEqual since it works in a defun.
      self.assertAllEqual(1, summary_ops.get_step())
      summary_ops.set_step(constant_op.constant(2))
      self.assertAllEqual(2, summary_ops.get_step())
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testGetSetStep_variable(self):
    """`set_step` keeps a live reference to the variable it was given."""
    with context.eager_mode():
      try:
        mystep = variables.Variable(0)
        summary_ops.set_step(mystep)
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        mystep.assign_add(1)
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        summary_ops.get_step().assign_add(1)
        self.assertAllEqual(2, summary_ops.get_step().read_value())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  def testGetSetStep_variable_fromFunction(self):
    """A step variable set inside a tf.function stays usable afterwards."""
    with context.eager_mode():
      try:
        @def_function.function
        def set_step(step):
          summary_ops.set_step(step)
          return summary_ops.get_step()
        @def_function.function
        def get_and_increment():
          summary_ops.get_step().assign_add(1)
          return summary_ops.get_step()
        mystep = variables.Variable(0)
        self.assertAllEqual(0, set_step(mystep))
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        self.assertAllEqual(1, get_and_increment())
        self.assertAllEqual(2, get_and_increment())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(3, get_and_increment())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  @test_util.also_run_as_tf_function
  def testSummaryScope(self):
    """summary_scope yields (tag, scope) pairs that nest with name scopes."""
    with summary_ops.summary_scope('foo') as (tag, scope):
      self.assertEqual('foo', tag)
      self.assertEqual('foo/', scope)
      with summary_ops.summary_scope('bar') as (tag, scope):
        self.assertEqual('foo/bar', tag)
        self.assertEqual('foo/bar/', scope)
      with summary_ops.summary_scope('with/slash') as (tag, scope):
        self.assertEqual('foo/with/slash', tag)
        self.assertEqual('foo/with/slash/', scope)
      with ops.name_scope(None, skip_on_eager=False):
        with summary_ops.summary_scope('unnested') as (tag, scope):
          self.assertEqual('unnested', tag)
          self.assertEqual('unnested/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_defaultName(self):
    """A None name falls back to 'summary' or to the given default name."""
    with summary_ops.summary_scope(None) as (tag, scope):
      self.assertEqual('summary', tag)
      self.assertEqual('summary/', scope)
    with summary_ops.summary_scope(None, 'backup') as (tag, scope):
      self.assertEqual('backup', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_handlesCharactersIllegalForScope(self):
    """Illegal scope characters are kept in the tag but dropped from scope."""
    with summary_ops.summary_scope('f?o?o') as (tag, scope):
      self.assertEqual('f?o?o', tag)
      self.assertEqual('foo/', scope)
    # If all characters aren't legal for a scope name, use default name.
    with summary_ops.summary_scope('???', 'backup') as (tag, scope):
      self.assertEqual('???', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_nameNotUniquifiedForTag(self):
    """The tag stays unchanged even when the scope name is already taken."""
    constant_op.constant(0, name='foo')
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with ops.name_scope('with', skip_on_eager=False):
      constant_op.constant(0, name='slash')
    with summary_ops.summary_scope('with/slash') as (tag, _):
      self.assertEqual('with/slash', tag)

  def testAllV2SummaryOps(self):
    """all_v2_summary_ops returns ops built while a default writer was set."""
    logdir = self.get_temp_dir()
    def define_ops():
      result = []
      # TF 2.0 summary ops
      result.append(summary_ops.write('write', 1, step=0))
      result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb'))
      # TF 1.x tf.contrib.summary ops
      result.append(summary_ops.generic('tensor', 1, step=1))
      result.append(summary_ops.scalar('scalar', 2.0, step=1))
      result.append(summary_ops.histogram('histogram', [1.0], step=1))
      result.append(summary_ops.image('image', [[[[1.0]]]], step=1))
      result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1))
      return result
    with context.graph_mode():
      ops_without_writer = define_ops()
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          ops_recording_on = define_ops()
        with summary_ops.record_if(False):
          ops_recording_off = define_ops()
      # We should be collecting all ops defined with a default writer present,
      # regardless of whether recording was set on or off, but not those defined
      # without a writer at all.
      del ops_without_writer
      expected_ops = ops_recording_on + ops_recording_off
      self.assertCountEqual(expected_ops, summary_ops.all_v2_summary_ops())

  def testShouldRecordSummaries_defaultState(self):
    """Recording is on only while a default writer is registered."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      w = summary_ops.create_file_writer_v2(logdir)
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with w.as_default():
        # Should be enabled only when default writer is registered.
        self.assertAllEqual(True, summary_ops.should_record_summaries())
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with summary_ops.record_if(True):
        # Should be disabled when no default writer, even with record_if(True).
        self.assertAllEqual(False, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_constants(self):
    """The innermost constant record_if condition wins."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          self.assertAllEqual(True, summary_ops.should_record_summaries())
        with summary_ops.record_if(False):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          with summary_ops.record_if(True):
            self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_variable(self):
    """A boolean variable condition is re-read on each query."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond = variables.Variable(False)
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond.assign(True)
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_callable(self):
    """A callable condition is re-evaluated on each query."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond_box = [False]
        cond = lambda: cond_box[0]
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond_box[0] = True
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_fromFunction(self):
    """should_record_summaries honors writer and record_if in a tf.function."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool)])
      def f(cond):
        results = []
        results.append(summary_ops.should_record_summaries())
        with writer.as_default():
          results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(False):
            results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(cond):
            results.append(summary_ops.should_record_summaries())
        return results
      self.assertAllEqual([False, True, False, True], f(True))
      self.assertAllEqual([False, True, False, False], f(False))
752
753
class SummaryWriterTest(test_util.TensorFlowTestCase):
  """Tests for file writers created by summary_ops.create_file_writer_v2().

  Many tests count events in the logdir; the first event in every file is the
  file_version event, so a freshly created writer already accounts for one.
  """

  def testCreate_withInitAndClose(self):
    """Repeated init() on an open writer is a no-op; close() flushes."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      # Large queue/flush settings keep the write below buffered until an
      # explicit flush, as the count assertions check.
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      get_total = lambda: len(events_from_logdir(logdir))
      self.assertEqual(1, get_total())  # file_version Event
      # Calling init() again while writer is open has no effect
      writer.init()
      self.assertEqual(1, get_total())
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Calling .close() should do an implicit flush
        writer.close()
        self.assertEqual(2, get_total())

  def testCreate_fromFunction(self):
    """A writer created inside a tf.function produces exactly one event file."""
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      # Returned SummaryWriter must be stored in a non-local variable so it
      # lives throughout the function execution.
      if not hasattr(f, 'writer'):
        f.writer = summary_ops.create_file_writer_v2(logdir)
    with context.eager_mode():
      f()
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertEqual(1, len(event_files))

  def testCreate_graphTensorArgument_raisesError(self):
    """A graph-mode tensor passed as logdir in eager mode raises ValueError.

    No event file may be created in that case.
    """
    logdir = self.get_temp_dir()
    # Build the tensor in graph mode so it is not usable from eager code.
    with context.graph_mode():
      logdir_tensor = constant_op.constant(logdir)
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        summary_ops.create_file_writer_v2(logdir_tensor)
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
    """A tensor logdir built inside a tf.function raises ValueError.

    No event file may be created in that case.
    """
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        f()
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_unpersistedResource_raisesError(self):
    """Using a function-local writer that is not kept alive fails.

    Unlike testCreate_fromFunction, the writer is not stored anywhere
    outside the function body, and the test expects NotFoundError.
    """
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        pass  # Calling .as_default() is enough to indicate use.
    with context.eager_mode():
      # TODO(nickfelt): change this to a better error
      with self.assertRaisesRegex(
          errors.NotFoundError, 'Resource.*does not exist'):
        f()
    # Even though we didn't use it, an event file will have been created.
    self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*'))))

  def testCreate_immediateSetAsDefault_retainsReference(self):
    """set_as_default() on an unreferenced writer keeps it usable for flush()."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        # No local name holds the writer; only the default-writer state does.
        summary_ops.create_file_writer_v2(logdir).set_as_default()
        summary_ops.flush()
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testCreate_immediateAsDefault_retainsReference(self):
    """as_default() on an unreferenced writer keeps it usable for flush()."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.flush()

  def testCreate_avoidsFilenameCollision(self):
    """Ten eager writers on one logdir produce ten distinct event files."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      for _ in range(10):
        summary_ops.create_file_writer_v2(logdir)
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testCreate_graphMode_avoidsFilenameCollision(self):
    """Ten init()/close() cycles in one graph-mode session yield ten files."""
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      writer = summary_ops.create_file_writer_v2(logdir)
      with self.cached_session() as sess:
        for _ in range(10):
          sess.run(writer.init())
          sess.run(writer.close())
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testNoSharing(self):
    """Two eager writers on the same logdir write to separate files."""
    # Two writers with the same logdir should not share state.
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer1 = summary_ops.create_file_writer_v2(logdir)
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(event_files))
      file1 = event_files[0]

      writer2 = summary_ops.create_file_writer_v2(logdir)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(2, len(event_files))
      # Whatever file is not writer1's must belong to writer2.
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)

    # Each file holds file_version plus exactly its own writer's two events.
    events = iter(events_from_file(file1))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(1, next(events).step)
    self.assertEqual(1, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))
    events = iter(events_from_file(file2))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(2, next(events).step)
    self.assertEqual(2, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))

  def testNoSharing_fromFunction(self):
    """Two function-created writers on the same logdir use separate files.

    Each tf.function keeps its writer alive by storing it as a function
    attribute.
    """
    logdir = self.get_temp_dir()
    @def_function.function
    def f1():
      if not hasattr(f1, 'writer'):
        f1.writer = summary_ops.create_file_writer_v2(logdir)
      with f1.writer.as_default():
        summary_ops.write('tag', 1, step=1)
    @def_function.function
    def f2():
      if not hasattr(f2, 'writer'):
        f2.writer = summary_ops.create_file_writer_v2(logdir)
      with f2.writer.as_default():
        summary_ops.write('tag', 1, step=2)
    with context.eager_mode():
      f1()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(event_files))
      file1 = event_files[0]

      f2()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(2, len(event_files))
      # Whatever file is not f1's must belong to f2.
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      f1()
      f2()

    # Each file holds file_version plus exactly its own writer's two events.
    events = iter(events_from_file(file1))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(1, next(events).step)
    self.assertEqual(1, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))
    events = iter(events_from_file(file2))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(2, next(events).step)
    self.assertEqual(2, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))

  def testMaxQueue(self):
    """With max_queue=1, the second queued write flushes both events."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      # A huge flush_millis isolates queue-size-triggered flushing.
      with summary_ops.create_file_writer_v2(
          logdir, max_queue=1, flush_millis=999999).as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Should flush after second summary since max_queue = 1
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())

  def testWriterFlush(self):
    """writer.flush() and exiting as_default() both flush pending events."""
    logdir = self.get_temp_dir()
    get_total = lambda: len(events_from_logdir(logdir))
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      self.assertEqual(1, get_total())  # file_version Event
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        writer.flush()
        self.assertEqual(2, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(2, get_total())
      # Exiting the "as_default()" should do an implicit flush
      self.assertEqual(3, get_total())

  def testFlushFunction(self):
    """summary_ops.flush() flushes the default or an explicitly given writer.

    The explicit `writer` argument is exercised both as the writer object and
    as its underlying `_resource`.
    """
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=999999, flush_millis=999999)
      with writer.as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        summary_ops.flush()
        self.assertEqual(3, get_total())
        # Test "writer" parameter
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())
        summary_ops.flush(writer=writer)
        self.assertEqual(4, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(4, get_total())
        summary_ops.flush(writer=writer._resource)  # pylint:disable=protected-access
        self.assertEqual(5, get_total())

  @test_util.assert_no_new_tensors
  def testNoMemoryLeak_graphMode(self):
    """Creating a writer in graph mode leaks no tensors (per decorator)."""
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      summary_ops.create_file_writer_v2(logdir)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNoMemoryLeak_eagerMode(self):
    """Creating and writing with an eager writer leaks no Python objects."""
    # No explicit eager_mode() here; eager execution is presumably provided
    # by the decorator (per its name) — NOTE(review): confirm.
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer_v2(logdir).as_default():
      summary_ops.write('tag', 1, step=0)

  def testClose_preventsLaterUse(self):
    """After close(), close()/flush() are no-ops and other uses raise."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      writer.close()
      writer.close()  # redundant close() is a no-op
      writer.flush()  # redundant flush() is a no-op
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.init()
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        with writer.as_default():
          self.fail('should not get here')
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.set_as_default()

  def testClose_closesOpenFile(self):
    """close() releases the OS file handle of the event file (needs psutil)."""
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    # Paths of all files currently open by this process.
    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(files))
      eventfile = files[0]
      self.assertIn(eventfile, get_open_filenames())
      writer.close()
      self.assertNotIn(eventfile, get_open_filenames())

  def testDereference_closesOpenFile(self):
    """Dropping the last reference releases the event file's OS handle."""
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    # Paths of all files currently open by this process.
    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(files))
      eventfile = files[0]
      self.assertIn(eventfile, get_open_filenames())
      del writer
      self.assertNotIn(eventfile, get_open_filenames())
1048
1049
class SummaryWriterSavedModelTest(test_util.TensorFlowTestCase):
  """Tests trackable writers exported in a SavedModel and loaded back."""

  def testWriter_savedAsModuleProperty_loadInEagerMode(self):
    """Each eager load of the SavedModel writes to its own new event file.

    Three files are expected: one created when Model is constructed before
    saving, and one for each of the two loads.
    """
    with context.eager_mode():
      class Model(module.Module):

        def __init__(self, model_dir):
          self._writer = summary_ops.create_file_writer_v2(
              model_dir, experimental_trackable=True)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
        ])
        def train(self, step):
          with self._writer.as_default():
            summary_ops.write('tag', 'foo', step=step)
          return constant_op.constant(0)

      logdir = self.get_temp_dir()
      to_export = Model(logdir)
      pre_save_files = set(events_from_multifile_logdir(logdir))
      export_dir = os.path.join(logdir, 'export')
      saved_model_save.save(
          to_export, export_dir, signatures={'train': to_export.train})

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access
    with context.eager_mode():
      restored = saved_model_load.load(export_dir)
      restored.train(1)
      restored.train(2)
      post_restore_files = set(events_from_multifile_logdir(logdir))
      restored2 = saved_model_load.load(export_dir)
      restored2.train(3)
      restored2.train(4)
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)
      self.assertLen(files_to_events, 3)
      # Asserts a set difference holds exactly one file and returns it.
      def unwrap_singleton(iterable):
        self.assertLen(iterable, 1)
        return next(iter(iterable))
      restore_file = unwrap_singleton(post_restore_files - pre_save_files)
      restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
      restore_events = files_to_events[restore_file]
      restore2_events = files_to_events[restore2_file]
      # Index 0 of each file is the file_version event; steps follow it.
      self.assertLen(restore_events, 3)
      self.assertEqual(1, restore_events[1].step)
      self.assertEqual(2, restore_events[2].step)
      self.assertLen(restore2_events, 3)
      self.assertEqual(3, restore2_events[1].step)
      self.assertEqual(4, restore2_events[2].step)

  def testWriter_savedAsModuleProperty_loadInGraphMode(self):
    """Loads within one TF1 session share a file; a new session gets another.

    Three files are expected: one created when Model is constructed before
    saving, one shared by both loads in the first session, and one for the
    second session.
    """
    with context.eager_mode():

      class Model(module.Module):

        def __init__(self, model_dir):
          self._writer = summary_ops.create_file_writer_v2(
              model_dir, experimental_trackable=True)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
        ])
        def train(self, step):
          with self._writer.as_default():
            summary_ops.write('tag', 'foo', step=step)
          return constant_op.constant(0)

      logdir = self.get_temp_dir()
      to_export = Model(logdir)
      pre_save_files = set(events_from_multifile_logdir(logdir))
      export_dir = os.path.join(logdir, 'export')
      saved_model_save.save(
          to_export, export_dir, signatures={'train': to_export.train})

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access

    def load_and_run_model(sess, input_values):
      """Load and run the SavedModel signature in the TF 1.x style."""
      model = saved_model_loader.load(sess, [tag_constants.SERVING], export_dir)
      signature = model.signature_def['train']
      inputs = list(signature.inputs.values())
      assert len(inputs) == 1, inputs
      outputs = list(signature.outputs.values())
      assert len(outputs) == 1, outputs
      input_tensor = sess.graph.get_tensor_by_name(inputs[0].name)
      output_tensor = sess.graph.get_tensor_by_name(outputs[0].name)
      # One session.run per value; each run writes one summary at that step.
      for v in input_values:
        sess.run(output_tensor, feed_dict={input_tensor: v})

    with context.graph_mode(), ops.Graph().as_default():
      # Since writer shared_name is fixed, within a single session, all loads of
      # this SavedModel will refer to a single writer resource, so it will be
      # initialized only once and write to a single file.
      with self.session() as sess:
        load_and_run_model(sess, [1, 2])
        load_and_run_model(sess, [3, 4])
      post_restore_files = set(events_from_multifile_logdir(logdir))
      # New session will recreate the resource and write to a second file.
      with self.session() as sess:
        load_and_run_model(sess, [5, 6])
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)

    self.assertLen(files_to_events, 3)
    # Asserts a set difference holds exactly one file and returns it.
    def unwrap_singleton(iterable):
      self.assertLen(iterable, 1)
      return next(iter(iterable))
    restore_file = unwrap_singleton(post_restore_files - pre_save_files)
    restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
    restore_events = files_to_events[restore_file]
    restore2_events = files_to_events[restore2_file]
    # Index 0 of each file is the file_version event; steps follow it.
    self.assertLen(restore_events, 5)
    self.assertEqual(1, restore_events[1].step)
    self.assertEqual(2, restore_events[2].step)
    self.assertEqual(3, restore_events[3].step)
    self.assertEqual(4, restore_events[4].step)
    self.assertLen(restore2_events, 3)
    self.assertEqual(5, restore2_events[1].step)
    self.assertEqual(6, restore2_events[2].step)
1172
1173
class NoopWriterTest(test_util.TensorFlowTestCase):
  """Tests for the writer returned by summary_ops.create_noop_writer()."""

  def testNoopWriter_doesNothing(self):
    """All writer methods run, but nothing is written and no file appears."""
    temp_dir = self.get_temp_dir()
    with context.eager_mode():
      noop = summary_ops.create_noop_writer()
      noop.init()
      with noop.as_default():
        wrote = summary_ops.write('test', 1.0, step=0)
      noop.flush()
      noop.close()
    self.assertFalse(wrote)  # Should have found no active writer
    self.assertEmpty(gfile.Glob(os.path.join(temp_dir, '*')))

  def testNoopWriter_asNestedContext_isTransparent(self):
    """A no-op writer nested in a real writer's scope does not block writes."""
    temp_dir = self.get_temp_dir()
    with context.eager_mode():
      file_writer = summary_ops.create_file_writer_v2(temp_dir)
      noop = summary_ops.create_noop_writer()
      with file_writer.as_default():
        outcomes = [summary_ops.write('first', 1.0, step=0)]
        with noop.as_default():
          outcomes.append(summary_ops.write('second', 1.0, step=0))
        outcomes.append(summary_ops.write('third', 1.0, step=0))
    # All ops should have written, including the one inside the no-op writer,
    # since it doesn't actively *disable* writing - it just behaves as if that
    # entire `with` block wasn't there at all.
    self.assertAllEqual(outcomes, [True, True, True])

  def testNoopWriter_setAsDefault(self):
    """set_as_default() on a no-op writer leaves no active writer."""
    try:
      with context.eager_mode():
        noop = summary_ops.create_noop_writer()
        noop.set_as_default()
        wrote = summary_ops.write('test', 1.0, step=0)
      self.assertFalse(wrote)  # Should have found no active writer
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access
1214
1215
class SummaryOpsTest(test_util.TensorFlowTestCase):
  """Tests for the run_metadata, trace and graph summary ops (TF2 only).

  The helper methods write through a fresh file writer and return the first
  event after the file_version event, i.e. the event under test.
  """

  def tearDown(self):
    """Turns tracing off after every test, then runs the base tearDown."""
    try:
      # A test (or a failing one) may leave a trace enabled; never leak it.
      summary_ops.trace_off()
    finally:
      # Always chain to the base class so its own cleanup is not skipped.
      super(SummaryOpsTest, self).tearDown()

  def exec_summary_op(self, summary_op_fn):
    """Runs `summary_op_fn` under a fresh file writer and returns its event.

    Args:
      summary_op_fn: Zero-argument callable that writes the summary under test.

    Returns:
      The first tf.Event after the file_version event in the new event file.
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_op_fn()
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata(self, *args, **kwargs):
    """Writes `summary_ops.run_metadata(*args, **kwargs)`; returns its event."""
    return self.exec_summary_op(
        lambda: summary_ops.run_metadata(*args, **kwargs))

  def run_metadata_graphs(self, *args, **kwargs):
    """Writes `summary_ops.run_metadata_graphs(...)`; returns its event."""
    return self.exec_summary_op(
        lambda: summary_ops.run_metadata_graphs(*args, **kwargs))

  def create_run_metadata(self):
    """Builds a RunMetadata with one step-stats node and one function graph."""
    step_stats = step_stats_pb2.StepStats(dev_stats=[
        step_stats_pb2.DeviceStepStats(
            device='cpu:0',
            node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')])
    ])
    return config_pb2.RunMetadata(
        function_graphs=[
            config_pb2.RunMetadata.FunctionGraphs(
                pre_optimization_graph=graph_pb2.GraphDef(
                    node=[node_def_pb2.NodeDef(name='foo')]))
        ],
        step_stats=step_stats)

  def run_trace(self, f, step=1):
    """Traces `f` with graph collection, exports it, and returns the event.

    Args:
      f: Callable to run while tracing is enabled.
      step: Step passed to trace_export; None means use the default step.

    Returns:
      The first tf.Event after the file_version event.
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=False)
    with writer.as_default():
      f()
      summary_ops.trace_export(name='foo', step=step)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  @test_util.run_v2_only
  def testRunMetadata_usesNameAsTag(self):
    """The summary tag is the op name prefixed by the enclosing name scope."""
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadata_summaryMetadata(self):
    """run_metadata tags its value with the graph_run_metadata plugin."""
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadata_wholeRunMetadata(self):
    """run_metadata serializes the complete RunMetadata proto."""
    expected_run_metadata = """
      step_stats {
        dev_stats {
          device: "cpu:0"
          node_stats {
            node_name: "hello"
          }
        }
      }
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadata_usesDefaultStep(self):
    """Without an explicit step, run_metadata uses the set_step() value."""
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesNameAsTag(self):
    """The summary tag is the op name prefixed by the enclosing name scope."""
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadataGraph_summaryMetadata(self):
    """run_metadata_graphs tags its value with its own plugin name."""
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata_graph"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_runMetadataFragment(self):
    """run_metadata_graphs keeps only function_graphs, dropping step_stats."""
    expected_run_metadata = """
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()

    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesDefaultStep(self):
    """Without an explicit step, run_metadata_graphs uses set_step()."""
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata_graphs(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace(self):
    """Exporting a graph trace writes a value that parses as RunMetadata."""

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    event = self.run_trace(f)

    first_val = event.summary.value[0]
    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])

    # Content of function_graphs is large and, for instance, device can change.
    # NOTE(review): hasattr() on a declared proto field is always true, so
    # this mainly checks that the value parsed as RunMetadata.
    self.assertTrue(hasattr(actual_run_metadata, 'function_graphs'))

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInFunction(self):
    """trace_on() inside a tf.function only logs a warning."""

    @def_function.function
    def f():
      summary_ops.trace_on(graph=True, profiler=False)
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot enable trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInGraphMode(self):
    """trace_on() in graph mode only logs a warning."""
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_on(graph=True, profiler=False)
      self.assertRegex(
          str(mock_log.call_args), 'Must enable trace in eager mode.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceWithoutTrace(self):
    """trace_export() without a prior trace_on() raises ValueError."""
    with self.assertRaisesRegex(ValueError,
                                'Must enable trace before export.'):
      summary_ops.trace_export(name='foo', step=1)

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInFunction(self):
    """trace_export() inside a tf.function only logs a warning."""
    summary_ops.trace_on(graph=True, profiler=False)

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      summary_ops.trace_export(name='foo', step=1)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot export trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInGraphMode(self):
    """trace_export() in graph mode only logs a warning."""
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_export(name='foo', step=1)
      self.assertRegex(
          str(mock_log.call_args),
          'Can only export trace while executing eagerly.')

  @test_util.run_v2_only
  def testTrace_usesDefaultStep(self):
    """trace_export(step=None) uses the set_step() value."""

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    try:
      summary_ops.set_step(42)
      event = self.run_trace(f, step=None)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace_withProfiler(self):
    """Tracing with the profiler enabled and exporting does not raise."""

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=True)
    profiler_outdir = self.get_temp_dir()
    with writer.as_default():
      f()
      summary_ops.trace_export(
          name='foo', step=1, profiler_outdir=profiler_outdir)
    writer.close()

  @test_util.run_v2_only
  def testGraph_graph(self):
    """graph() accepts a tf.Graph and writes an event with a graph_def."""

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph)

    event = self.exec_summary_op(summary_op_fn)
    self.assertIsNotNone(event.graph_def)

  @test_util.run_v2_only
  def testGraph_graphDef(self):
    """graph() accepts a GraphDef and writes an event with a graph_def."""

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph.as_graph_def())

    event = self.exec_summary_op(summary_op_fn)
    self.assertIsNotNone(event.graph_def)

  @test_util.run_v2_only
  def testGraph_invalidData(self):
    """graph() rejects data that is neither a Graph nor a GraphDef."""
    def summary_op_fn():
      summary_ops.graph('hello')

    with self.assertRaisesRegex(
        ValueError,
        r'\'graph_data\' is not tf.Graph or tf.compat.v1.GraphDef',
    ):
      self.exec_summary_op(summary_op_fn)

  @test_util.run_v2_only
  def testGraph_fromGraphMode(self):
    """graph() called inside a tf.function raises ValueError."""

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    @def_function.function
    def g(graph):
      summary_ops.graph(graph)

    def summary_op_fn():
      graph_def = f.get_concrete_function().graph.as_graph_def(add_shapes=True)
      func_graph = constant_op.constant(graph_def.SerializeToString())
      g(func_graph)

    with self.assertRaisesRegex(
        ValueError,
        r'graph\(\) cannot be invoked inside a graph context.',
    ):
      self.exec_summary_op(summary_op_fn)
1565
1566
def events_from_file(filepath):
  """Returns all events in a single event file.

  Each record is decoded as it is read from the file, rather than first
  buffering every raw record in an intermediate list.

  Args:
    filepath: Path to the event file.

  Returns:
    A list of all tf.Event protos in the event file, in the order the record
    iterator yields them.
  """
  return [
      event_pb2.Event.FromString(record)
      for record in tf_record.tf_record_iterator(filepath)
  ]
1583
1584
def events_from_logdir(logdir):
  """Loads the events stored in the lone event file under `logdir`.

  Args:
    logdir: Directory expected to hold exactly one event file.

  Returns:
    A list of tf.Event protos read from that file.

  Raises:
    AssertionError: If `logdir` is missing or does not contain exactly one
      entry.
  """
  assert gfile.Exists(logdir)
  entries = gfile.ListDirectory(logdir)
  assert len(entries) == 1, (
      'Found not exactly one file in logdir: %s' % entries)
  event_file = os.path.join(logdir, entries[0])
  return events_from_file(event_file)
1601
1602
def events_from_multifile_logdir(logdir):
  """Returns map of filename to events for all `tfevents` files in the logdir.

  Directory entries whose names do not contain the substring `tfevents` are
  ignored.

  Args:
    logdir: The directory from which to load events.

  Returns:
    A dict mapping from relative filenames to lists of tf.Event protos. The
    dict is empty when no matching files are present.

  Raises:
    AssertionError: If logdir does not exist.
  """
  assert gfile.Exists(logdir)
  event_files = [
      filename for filename in gfile.ListDirectory(logdir)
      if 'tfevents' in filename
  ]
  return {
      filename: events_from_file(os.path.join(logdir, filename))
      for filename in event_files
  }
1618
1619
def to_numpy(summary_value):
  """Decodes the `tensor` field of a summary value into a numpy array.

  Args:
    summary_value: A proto with a `tensor` field holding a TensorProto,
      presumably a `summary_pb2.Summary.Value` (confirm against callers).

  Returns:
    The ndarray produced by `tensor_util.MakeNdarray` for that tensor.
  """
  tensor_proto = summary_value.tensor
  return tensor_util.MakeNdarray(tensor_proto)
1622
1623
if __name__ == '__main__':
  # Run this module's test cases when it is executed as a script. `test` is
  # not among the visible header imports; presumably it is imported further
  # down the import block. NOTE(review): confirm.
  test.main()
1626