• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for V2 summary ops from summary_ops_v2."""
16
17import os
18import unittest
19
20from tensorflow.core.framework import graph_pb2
21from tensorflow.core.framework import node_def_pb2
22from tensorflow.core.framework import step_stats_pb2
23from tensorflow.core.framework import summary_pb2
24from tensorflow.core.protobuf import config_pb2
25from tensorflow.core.util import event_pb2
26from tensorflow.python.eager import context
27from tensorflow.python.eager import def_function
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import test_util
35from tensorflow.python.lib.io import tf_record
36from tensorflow.python.module import module
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import summary_ops_v2 as summary_ops
39from tensorflow.python.ops import variables
40from tensorflow.python.platform import gfile
41from tensorflow.python.platform import test
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.saved_model import load as saved_model_load
44from tensorflow.python.saved_model import loader as saved_model_loader
45from tensorflow.python.saved_model import save as saved_model_save
46from tensorflow.python.saved_model import tag_constants
47
48
class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
  """Tests for writing summaries, default steps and recording conditions.

  Most tests write to a temporary logdir and read the events back with the
  module-level `events_from_logdir()` helper. Index 0 of the returned events is
  the file_version event written when the writer is created, so the summaries
  written by a test start at index 1.

  The default step and default writer in `summary_ops` are module-level state
  that outlives a single test, so tests that change them restore the default
  state in a `finally` clause.
  """

  def testWrite(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write('tag', 42, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write('tag', 42, step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_metadata(self):
    logdir = self.get_temp_dir()
    metadata = summary_pb2.SummaryMetadata()
    metadata.plugin_data.plugin_name = 'foo'
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        # The same metadata is passed in three forms: proto, serialized bytes
        # and a string tensor; all must round-trip to the same proto.
        summary_ops.write('obj', 0, 0, metadata=metadata)
        summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString())
        m = constant_op.constant(metadata.SerializeToString())
        summary_ops.write('string_tensor', 0, 0, metadata=m)
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(metadata, events[1].summary.value[0].metadata)
    self.assertEqual(metadata, events[2].summary.value[0].metadata)
    self.assertEqual(metadata, events[3].summary.value[0].metadata)

  def testWrite_name(self):
    @def_function.function
    def f():
      output = summary_ops.write('tag', 42, step=12, name='anonymous')
      self.assertTrue(output.name.startswith('anonymous'))
    f()

  def testWrite_ndarray(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value))

  def testWrite_tensor(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      t = constant_op.constant([[1, 2], [3, 4]])
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', t, step=12)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_tensor_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f(t):
        with writer.as_default():
          summary_ops.write('tag', t, step=12)
      t = constant_op.constant([[1, 2], [3, 4]])
      f(t)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_stringTensor(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [b'foo', b'bar'], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([b'foo', b'bar'], to_numpy(value))

  @test_util.run_gpu_only
  def testWrite_gpuDeviceContext(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with ops.device('/GPU:0'):
          value = constant_op.constant(42.0)
          step = constant_op.constant(12, dtype=dtypes.int64)
          summary_ops.write('tag', value, step=step).numpy()
    empty_metadata = summary_pb2.SummaryMetadata()
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertEqual(42, to_numpy(events[1].summary.value[0]))
    self.assertEqual(empty_metadata, events[1].summary.value[0].metadata)

  @test_util.also_run_as_tf_function
  def testWrite_noDefaultWriter(self):
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42, step=0))

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfAlsoNoDefaultWriter(self):
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42))

  def testWrite_noStep(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(ValueError, 'No step set'):
          summary_ops.write('tag', 42)

  def testWrite_noStep_okayIfNotRecordingSummaries(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('tag', 42))

  def testWrite_usingDefaultStep(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        with summary_ops.create_file_writer_v2(logdir).as_default():
          summary_ops.set_step(1)
          summary_ops.write('tag', 1.0)
          summary_ops.set_step(2)
          summary_ops.write('tag', 1.0)
          mystep = variables.Variable(10, dtype=dtypes.int64)
          summary_ops.set_step(mystep)
          summary_ops.write('tag', 1.0)
          mystep.assign_add(1)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertEqual(5, len(events))
      self.assertEqual(1, events[1].step)
      self.assertEqual(2, events[2].step)
      self.assertEqual(10, events[3].step)
      self.assertEqual(11, events[4].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromFunction(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        summary_ops.set_step(1)
        f()
        summary_ops.set_step(2)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the function was first traced.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromFunction(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        f()
        mystep.assign_add(1)
        f()
        mystep.assign(10)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromLegacyGraph(self):
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        summary_ops.set_step(1)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        summary_ops.set_step(2)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(write_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the graph was constructed.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromLegacyGraph(self):
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        first_assign_op = mystep.assign_add(1)
        second_assign_op = mystep.assign(10)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(mystep.initializer)
          sess.run(write_op)
          sess.run(first_assign_op)
          sess.run(write_op)
          sess.run(second_assign_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        with writer.as_default(step=1):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=2):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            summary_ops.set_step(3)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 1, 2, 1, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        with writer.as_default(step=mystep):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            mystep.assign(2)
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=3):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            mystep.assign(4)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3, 2, 4], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromSetAsDefault(self):
    """Constant default steps installed via `set_as_default(step=...)`."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        writer.set_as_default(step=1)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=2)
        summary_ops.write('tag', 1.0)
        # Without a step argument the previously set default step is kept.
        writer.set_as_default()
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 2], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests. set_as_default() is not
      # scoped like as_default(), so the default writer it installed must be
      # cleared explicitly in addition to the default step; otherwise later
      # tests that expect no default writer depend on test ordering.
      summary_ops.set_step(None)
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testWrite_usingDefaultStepVariable_fromSetAsDefault(self):
    """Variable default step installed via `set_as_default(step=...)`."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        writer.set_as_default(step=mystep)
        summary_ops.write('tag', 1.0)
        # Updating the variable is reflected in the next write.
        mystep.assign(2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=3)
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests. set_as_default() is not
      # scoped like as_default(), so the default writer it installed must be
      # cleared explicitly in addition to the default step; otherwise later
      # tests that expect no default writer depend on test ordering.
      summary_ops.set_step(None)
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testWrite_recordIf_constant(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        self.assertTrue(summary_ops.write('default', 1, step=0))
        with summary_ops.record_if(True):
          self.assertTrue(summary_ops.write('set_on', 1, step=0))
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('set_off', 1, step=0))
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_constant_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          # Use assertAllEqual instead of assertTrue since it works in a defun.
          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
          with summary_ops.record_if(True):
            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
          with summary_ops.record_if(False):
            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
      f()
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_callable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      step = variables.Variable(-1, dtype=dtypes.int64)
      # The condition is re-evaluated on every write: it bumps the step and
      # records only on even steps, so writes alternate True/False.
      def record_fn():
        step.assign_add(1)
        return int(step % 2) == 0
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(record_fn):
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_callable_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      step = variables.Variable(-1, dtype=dtypes.int64)
      @def_function.function
      def record_fn():
        step.assign_add(1)
        return math_ops.equal(step % 2, 0)
      @def_function.function
      def f():
        with writer.as_default():
          with summary_ops.record_if(record_fn):
            return [
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step)]
      self.assertAllEqual(f(), [True, False, True])
      self.assertAllEqual(f(), [False, True, False])
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_tensorInput_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
      def f(step):
        with writer.as_default():
          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
            return summary_ops.write('tag', 1, step=step)
      self.assertTrue(f(0))
      self.assertFalse(f(1))
      self.assertTrue(f(2))
      self.assertFalse(f(3))
      self.assertTrue(f(4))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWriteRawPb(self):
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_fromFunction(self):
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_multipleValues(self):
    logdir = self.get_temp_dir()
    pb1 = summary_pb2.Summary()
    pb1.value.add().simple_value = 1.0
    pb1.value.add().simple_value = 2.0
    pb2 = summary_pb2.Summary()
    pb2.value.add().simple_value = 3.0
    pb3 = summary_pb2.Summary()
    pb3.value.add().simple_value = 4.0
    pb3.value.add().simple_value = 5.0
    pb3.value.add().simple_value = 6.0
    pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)]
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pbs, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    # All three serialized protos are merged into a single event.
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    expected_pb = summary_pb2.Summary()
    for i in range(6):
      expected_pb.value.add().simple_value = i + 1.0
    self.assertProtoEquals(expected_pb, events[1].summary)

  def testWriteRawPb_invalidValue(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(
            errors.DataLossError,
            'Bad tf.compat.v1.Summary binary proto tensor string'):
          summary_ops.write_raw_pb('notaproto', step=12)

  @test_util.also_run_as_tf_function
  def testGetSetStep(self):
    try:
      self.assertIsNone(summary_ops.get_step())
      summary_ops.set_step(1)
      # Use assertAllEqual instead of assertEqual since it works in a defun.
      self.assertAllEqual(1, summary_ops.get_step())
      summary_ops.set_step(constant_op.constant(2))
      self.assertAllEqual(2, summary_ops.get_step())
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testGetSetStep_variable(self):
    with context.eager_mode():
      try:
        mystep = variables.Variable(0)
        summary_ops.set_step(mystep)
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        mystep.assign_add(1)
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        summary_ops.get_step().assign_add(1)
        self.assertAllEqual(2, summary_ops.get_step().read_value())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  def testGetSetStep_variable_fromFunction(self):
    with context.eager_mode():
      try:
        @def_function.function
        def set_step(step):
          summary_ops.set_step(step)
          return summary_ops.get_step()
        @def_function.function
        def get_and_increment():
          summary_ops.get_step().assign_add(1)
          return summary_ops.get_step()
        mystep = variables.Variable(0)
        self.assertAllEqual(0, set_step(mystep))
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        self.assertAllEqual(1, get_and_increment())
        self.assertAllEqual(2, get_and_increment())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(3, get_and_increment())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  @test_util.also_run_as_tf_function
  def testSummaryScope(self):
    with summary_ops.summary_scope('foo') as (tag, scope):
      self.assertEqual('foo', tag)
      self.assertEqual('foo/', scope)
      with summary_ops.summary_scope('bar') as (tag, scope):
        self.assertEqual('foo/bar', tag)
        self.assertEqual('foo/bar/', scope)
      with summary_ops.summary_scope('with/slash') as (tag, scope):
        self.assertEqual('foo/with/slash', tag)
        self.assertEqual('foo/with/slash/', scope)
      with ops.name_scope(None, skip_on_eager=False):
        with summary_ops.summary_scope('unnested') as (tag, scope):
          self.assertEqual('unnested', tag)
          self.assertEqual('unnested/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_defaultName(self):
    with summary_ops.summary_scope(None) as (tag, scope):
      self.assertEqual('summary', tag)
      self.assertEqual('summary/', scope)
    with summary_ops.summary_scope(None, 'backup') as (tag, scope):
      self.assertEqual('backup', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_handlesCharactersIllegalForScope(self):
    with summary_ops.summary_scope('f?o?o') as (tag, scope):
      self.assertEqual('f?o?o', tag)
      self.assertEqual('foo/', scope)
    # If all characters aren't legal for a scope name, use default name.
    with summary_ops.summary_scope('???', 'backup') as (tag, scope):
      self.assertEqual('???', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_nameNotUniquifiedForTag(self):
    constant_op.constant(0, name='foo')
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with ops.name_scope('with', skip_on_eager=False):
      constant_op.constant(0, name='slash')
    with summary_ops.summary_scope('with/slash') as (tag, _):
      self.assertEqual('with/slash', tag)

  def testAllV2SummaryOps(self):
    logdir = self.get_temp_dir()
    def define_ops():
      result = []
      # TF 2.0 summary ops
      result.append(summary_ops.write('write', 1, step=0))
      result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb'))
      # TF 1.x tf.contrib.summary ops
      result.append(summary_ops.generic('tensor', 1, step=1))
      result.append(summary_ops.scalar('scalar', 2.0, step=1))
      result.append(summary_ops.histogram('histogram', [1.0], step=1))
      result.append(summary_ops.image('image', [[[[1.0]]]], step=1))
      result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1))
      return result
    with context.graph_mode():
      ops_without_writer = define_ops()
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          ops_recording_on = define_ops()
        with summary_ops.record_if(False):
          ops_recording_off = define_ops()
      # We should be collecting all ops defined with a default writer present,
      # regardless of whether recording was set on or off, but not those defined
      # without a writer at all.
      del ops_without_writer
      expected_ops = ops_recording_on + ops_recording_off
      self.assertCountEqual(expected_ops, summary_ops.all_v2_summary_ops())

  def testShouldRecordSummaries_defaultState(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      w = summary_ops.create_file_writer_v2(logdir)
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with w.as_default():
        # Should be enabled only when default writer is registered.
        self.assertAllEqual(True, summary_ops.should_record_summaries())
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with summary_ops.record_if(True):
        # Should be disabled when no default writer, even with record_if(True).
        self.assertAllEqual(False, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_constants(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          self.assertAllEqual(True, summary_ops.should_record_summaries())
        with summary_ops.record_if(False):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          with summary_ops.record_if(True):
            self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_variable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond = variables.Variable(False)
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond.assign(True)
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_callable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        # A one-element list lets the lambda observe later mutations.
        cond_box = [False]
        cond = lambda: cond_box[0]
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond_box[0] = True
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool)])
      def f(cond):
        results = []
        results.append(summary_ops.should_record_summaries())
        with writer.as_default():
          results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(False):
            results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(cond):
            results.append(summary_ops.should_record_summaries())
        return results
      self.assertAllEqual([False, True, False, True], f(True))
      self.assertAllEqual([False, True, False, False], f(False))

  def testHasDefaultWriter_checkWriter(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with self.subTest(name='has_writer'):
        with summary_ops.create_file_writer_v2(logdir).as_default():
          self.assertTrue(summary_ops.has_default_writer())
      with self.subTest(name='no_writer'):
        self.assertFalse(summary_ops.has_default_writer())
755
756
class SummaryWriterTest(test_util.TensorFlowTestCase):
  """Tests for creating, flushing, closing and sharing summary writers."""

  def _assert_event_file_steps(self, filepath, expected_steps):
    """Asserts the exact contents of a single event file.

    The file must start with the `brain.Event:2` file_version event. That event
    must be followed by exactly one event per entry of `expected_steps`, and
    their `step` fields must match in order.

    Args:
      filepath: Path to the event file to read.
      expected_steps: Sequence of expected `step` values for the events that
        follow the file_version event.
    """
    events = events_from_file(filepath)
    self.assertLen(events, 1 + len(expected_steps))
    self.assertEqual('brain.Event:2', events[0].file_version)
    self.assertEqual(list(expected_steps), [event.step for event in events[1:]])

  def _open_filenames_getter(self):
    """Returns a zero-argument callable listing files open in this process.

    The callable returns a set of paths normalized with os.path.realpath. This
    lets them be compared with paths built from get_temp_dir() even when the
    temp dir is reached through a symlink.

    Raises:
      unittest.SkipTest: If psutil is not installed.
    """
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil') from None
    proc = psutil.Process()
    return lambda: {os.path.realpath(info[0]) for info in proc.open_files()}

  def testCreate_withInitAndClose(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      get_total = lambda: len(events_from_logdir(logdir))
      self.assertEqual(1, get_total())  # file_version Event
      # Calling init() again while writer is open has no effect
      writer.init()
      self.assertEqual(1, get_total())
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Calling .close() should do an implicit flush
        writer.close()
        self.assertEqual(2, get_total())

  def testCreate_fromFunction(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      # Returned SummaryWriter must be stored in a non-local variable so it
      # lives throughout the function execution.
      if not hasattr(f, 'writer'):
        f.writer = summary_ops.create_file_writer_v2(logdir)
    with context.eager_mode():
      f()
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 1)

  def testCreate_graphTensorArgument_raisesError(self):
    logdir = self.get_temp_dir()
    with context.graph_mode():
      logdir_tensor = constant_op.constant(logdir)
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        summary_ops.create_file_writer_v2(logdir_tensor)
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        f()
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_unpersistedResource_raisesError(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        pass  # Calling .as_default() is enough to indicate use.
    with context.eager_mode():
      # TODO(nickfelt): change this to a better error
      with self.assertRaisesRegex(
          errors.NotFoundError, 'Resource.*does not exist'):
        f()
    # Even though we didn't use it, an event file will have been created.
    self.assertLen(gfile.Glob(os.path.join(logdir, '*')), 1)

  def testCreate_immediateSetAsDefault_retainsReference(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        summary_ops.create_file_writer_v2(logdir).set_as_default()
        summary_ops.flush()
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testCreate_immediateAsDefault_retainsReference(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.flush()

  def testCreate_avoidsFilenameCollision(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      for _ in range(10):
        summary_ops.create_file_writer_v2(logdir)
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testCreate_graphMode_avoidsFilenameCollision(self):
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      writer = summary_ops.create_file_writer_v2(logdir)
      with self.cached_session() as sess:
        for _ in range(10):
          sess.run(writer.init())
          sess.run(writer.close())
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testNoSharing(self):
    # Two writers with the same logdir should not share state.
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer1 = summary_ops.create_file_writer_v2(logdir)
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertLen(event_files, 1)
      file1 = event_files[0]

      writer2 = summary_ops.create_file_writer_v2(logdir)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertLen(event_files, 2)
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)

    # Each writer must have written only its own steps to its own file.
    self._assert_event_file_steps(file1, [1, 1])
    self._assert_event_file_steps(file2, [2, 2])

  def testNoSharing_fromFunction(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f1():
      if not hasattr(f1, 'writer'):
        f1.writer = summary_ops.create_file_writer_v2(logdir)
      with f1.writer.as_default():
        summary_ops.write('tag', 1, step=1)
    @def_function.function
    def f2():
      if not hasattr(f2, 'writer'):
        f2.writer = summary_ops.create_file_writer_v2(logdir)
      with f2.writer.as_default():
        summary_ops.write('tag', 1, step=2)
    with context.eager_mode():
      f1()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertLen(event_files, 1)
      file1 = event_files[0]

      f2()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertLen(event_files, 2)
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      f1()
      f2()

    # Each function's writer must have written only its own steps.
    self._assert_event_file_steps(file1, [1, 1])
    self._assert_event_file_steps(file2, [2, 2])

  def testMaxQueue(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(
          logdir, max_queue=1, flush_millis=999999).as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Should flush after second summary since max_queue = 1
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())

  def testWriterFlush(self):
    logdir = self.get_temp_dir()
    get_total = lambda: len(events_from_logdir(logdir))
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      self.assertEqual(1, get_total())  # file_version Event
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        writer.flush()
        self.assertEqual(2, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(2, get_total())
      # Exiting the "as_default()" should do an implicit flush
      self.assertEqual(3, get_total())

  def testFlushFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=999999, flush_millis=999999)
      with writer.as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        summary_ops.flush()
        self.assertEqual(3, get_total())
        # Test "writer" parameter
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())
        summary_ops.flush(writer=writer)
        self.assertEqual(4, get_total())

  # Regression test for b/228097117.
  def testFlushFunction_disallowsInvalidWriterInput(self):
    with context.eager_mode():
      with self.assertRaisesRegex(ValueError, 'Invalid argument to flush'):
        summary_ops.flush(writer=())

  @test_util.assert_no_new_tensors
  def testNoMemoryLeak_graphMode(self):
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      summary_ops.create_file_writer_v2(logdir)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNoMemoryLeak_eagerMode(self):
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer_v2(logdir).as_default():
      summary_ops.write('tag', 1, step=0)

  def testClose_preventsLaterUse(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      writer.close()
      writer.close()  # redundant close() is a no-op
      writer.flush()  # redundant flush() is a no-op
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.init()
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        with writer.as_default():
          self.fail('should not get here')
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.set_as_default()

  def testClose_closesOpenFile(self):
    get_open_filenames = self._open_filenames_getter()
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertLen(files, 1)
      # Resolve symlinks so the path matches what the OS reports as open.
      eventfile = os.path.realpath(files[0])
      self.assertIn(eventfile, get_open_filenames())
      writer.close()
      self.assertNotIn(eventfile, get_open_filenames())

  def testDereference_closesOpenFile(self):
    get_open_filenames = self._open_filenames_getter()
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertLen(files, 1)
      # Resolve symlinks so the path matches what the OS reports as open.
      eventfile = os.path.realpath(files[0])
      self.assertIn(eventfile, get_open_filenames())
      # Dropping the only reference must close the file.
      del writer
      self.assertNotIn(eventfile, get_open_filenames())
1053
1054
class SummaryWriterSavedModelTest(test_util.TensorFlowTestCase):
  """Tests trackable summary writers saved in and loaded from SavedModels."""

  def _export_model(self):
    """Saves a Module that owns a trackable file writer as a SavedModel.

    The tests call this from within context.eager_mode(). The module's `train`
    signature takes an int64 scalar `step`, writes the string 'foo' under tag
    'tag' at that step via the module's writer, and returns a constant 0.

    Returns:
      A tuple `(exported_module, logdir, export_dir, pre_save_files)`:
      - exported_module: the saved module. It is returned so the caller keeps
        it, and its writer, referenced for the rest of the test.
      - logdir: the writer's log directory (the test's temp dir).
      - export_dir: the SavedModel directory, under logdir.
      - pre_save_files: the set of files reported by
        events_from_multifile_logdir(logdir) before saving.
    """

    class Model(module.Module):

      def __init__(self, model_dir):
        self._writer = summary_ops.create_file_writer_v2(
            model_dir, experimental_trackable=True)

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
      ])
      def train(self, step):
        with self._writer.as_default():
          summary_ops.write('tag', 'foo', step=step)
        return constant_op.constant(0)

    logdir = self.get_temp_dir()
    to_export = Model(logdir)
    pre_save_files = set(events_from_multifile_logdir(logdir))
    export_dir = os.path.join(logdir, 'export')
    saved_model_save.save(
        to_export, export_dir, signatures={'train': to_export.train})
    return to_export, logdir, export_dir, pre_save_files

  def _unwrap_singleton(self, iterable):
    """Asserts `iterable` has exactly one element and returns it."""
    self.assertLen(iterable, 1)
    return next(iter(iterable))

  def testWriter_savedAsModuleProperty_loadInEagerMode(self):
    with context.eager_mode():
      # The exported module stays referenced for the whole test; only the
      # loaded copies are exercised below.
      unused_module, logdir, export_dir, pre_save_files = self._export_model()

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access
    with context.eager_mode():
      restored = saved_model_load.load(export_dir)
      restored.train(1)
      restored.train(2)
      post_restore_files = set(events_from_multifile_logdir(logdir))
      restored2 = saved_model_load.load(export_dir)
      restored2.train(3)
      restored2.train(4)
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)
      self.assertLen(files_to_events, 3)
      # Each eager load creates its own writer, and hence its own new file.
      restore_file = self._unwrap_singleton(
          post_restore_files - pre_save_files)
      restore2_file = self._unwrap_singleton(
          post_restore2_files - post_restore_files)
      restore_events = files_to_events[restore_file]
      restore2_events = files_to_events[restore2_file]
      self.assertLen(restore_events, 3)
      self.assertEqual(1, restore_events[1].step)
      self.assertEqual(2, restore_events[2].step)
      self.assertLen(restore2_events, 3)
      self.assertEqual(3, restore2_events[1].step)
      self.assertEqual(4, restore2_events[2].step)

  def testWriter_savedAsModuleProperty_loadInGraphMode(self):
    with context.eager_mode():
      # The exported module stays referenced for the whole test; only the
      # loaded copies are exercised below.
      unused_module, logdir, export_dir, pre_save_files = self._export_model()

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access

    def load_and_run_model(sess, input_values):
      """Load and run the SavedModel signature in the TF 1.x style."""
      model = saved_model_loader.load(sess, [tag_constants.SERVING], export_dir)
      signature = model.signature_def['train']
      inputs = list(signature.inputs.values())
      assert len(inputs) == 1, inputs
      outputs = list(signature.outputs.values())
      assert len(outputs) == 1, outputs
      input_tensor = sess.graph.get_tensor_by_name(inputs[0].name)
      output_tensor = sess.graph.get_tensor_by_name(outputs[0].name)
      for v in input_values:
        sess.run(output_tensor, feed_dict={input_tensor: v})

    with context.graph_mode(), ops.Graph().as_default():
      # Since writer shared_name is fixed, within a single session, all loads of
      # this SavedModel will refer to a single writer resource, so it will be
      # initialized only once and write to a single file.
      with self.session() as sess:
        load_and_run_model(sess, [1, 2])
        load_and_run_model(sess, [3, 4])
      post_restore_files = set(events_from_multifile_logdir(logdir))
      # New session will recreate the resource and write to a second file.
      with self.session() as sess:
        load_and_run_model(sess, [5, 6])
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)

    self.assertLen(files_to_events, 3)
    restore_file = self._unwrap_singleton(post_restore_files - pre_save_files)
    restore2_file = self._unwrap_singleton(
        post_restore2_files - post_restore_files)
    restore_events = files_to_events[restore_file]
    restore2_events = files_to_events[restore2_file]
    self.assertLen(restore_events, 5)
    self.assertEqual(1, restore_events[1].step)
    self.assertEqual(2, restore_events[2].step)
    self.assertEqual(3, restore_events[3].step)
    self.assertEqual(4, restore_events[4].step)
    self.assertLen(restore2_events, 3)
    self.assertEqual(5, restore2_events[1].step)
    self.assertEqual(6, restore2_events[2].step)
1177
1178
class NoopWriterTest(test_util.TensorFlowTestCase):
  """Tests for the writer returned by summary_ops.create_noop_writer()."""

  def testNoopWriter_doesNothing(self):
    """All writer methods succeed, but nothing is recorded or written."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      noop = summary_ops.create_noop_writer()
      noop.init()
      with noop.as_default():
        written = summary_ops.write('test', 1.0, step=0)
      noop.flush()
      noop.close()
    # A falsy result means write() found no active writer.
    self.assertFalse(written)
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testNoopWriter_asNestedContext_isTransparent(self):
    """A no-op writer nested inside a real one leaves the real one active."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      file_writer = summary_ops.create_file_writer_v2(logdir)
      noop = summary_ops.create_noop_writer()
      outcomes = []
      with file_writer.as_default():
        outcomes.append(summary_ops.write('first', 1.0, step=0))
        with noop.as_default():
          outcomes.append(summary_ops.write('second', 1.0, step=0))
        outcomes.append(summary_ops.write('third', 1.0, step=0))
    # Every write succeeds, including the one inside the no-op writer's block:
    # the no-op writer does not actively *disable* writing, it behaves as if
    # its `with` block were absent.
    self.assertAllEqual(outcomes, [True, True, True])

  def testNoopWriter_setAsDefault(self):
    """set_as_default() on a no-op writer leaves writes inactive."""
    try:
      with context.eager_mode():
        noop = summary_ops.create_noop_writer()
        noop.set_as_default()
        written = summary_ops.write('test', 1.0, step=0)
      # A falsy result means write() found no active writer.
      self.assertFalse(written)
    finally:
      # Restore the global default writer whatever the outcome above.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access
1219
1220
class SummaryOpsTest(test_util.TensorFlowTestCase):
  """Tests for run_metadata, run_metadata_graphs, trace_* and graph ops."""

  def tearDown(self):
    # Turn tracing off so a test that enabled it (possibly without exporting)
    # cannot affect later tests.
    summary_ops.trace_off()
    super().tearDown()

  def exec_summary_op(self, summary_op_fn):
    """Runs `summary_op_fn` under a fresh file writer and returns its event.

    Must be called while executing eagerly. events_from_logdir() requires the
    temp dir to contain exactly one event file.

    Args:
      summary_op_fn: Zero-argument callable that writes one summary through
        the default writer.

    Returns:
      The first event after the initial file_version event.
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_op_fn()
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata(self, *args, **kwargs):
    """Writes summary_ops.run_metadata(*args, **kwargs) and returns its event.

    See exec_summary_op() for the writer setup and return value.
    """
    return self.exec_summary_op(
        lambda: summary_ops.run_metadata(*args, **kwargs))

  def run_metadata_graphs(self, *args, **kwargs):
    """Writes summary_ops.run_metadata_graphs(...) and returns its event.

    Arguments are forwarded unchanged. See exec_summary_op() for the writer
    setup and return value.
    """
    return self.exec_summary_op(
        lambda: summary_ops.run_metadata_graphs(*args, **kwargs))

  def create_run_metadata(self):
    """Builds a RunMetadata with one step-stats node and one function graph."""
    step_stats = step_stats_pb2.StepStats(dev_stats=[
        step_stats_pb2.DeviceStepStats(
            device='cpu:0',
            node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')])
    ])
    return config_pb2.RunMetadata(
        function_graphs=[
            config_pb2.RunMetadata.FunctionGraphs(
                pre_optimization_graph=graph_pb2.GraphDef(
                    node=[node_def_pb2.NodeDef(name='foo')]))
        ],
        step_stats=step_stats)

  def run_trace(self, f, step=1):
    """Traces a call to `f` with graph collection and returns the exported event.

    Args:
      f: Zero-argument callable to run while tracing is on.
      step: Step passed to trace_export(); None uses the default step.

    Returns:
      The first event after the initial file_version event.
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    # Tracing is enabled only after the writer exists, so that only `f` runs
    # while graphs are being collected.
    summary_ops.trace_on(graph=True, profiler=False)
    with writer.as_default():
      f()
      summary_ops.trace_export(name='foo', step=step)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  @test_util.run_v2_only
  def testRunMetadata_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadata_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadata_wholeRunMetadata(self):
    expected_run_metadata = """
      step_stats {
        dev_stats {
          device: "cpu:0"
          node_stats {
            node_name: "hello"
          }
        }
      }
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadata_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadataGraph_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata_graph"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_runMetadataFragment(self):
    expected_run_metadata = """
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()

    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata_graphs(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    event = self.run_trace(f)

    first_val = event.summary.value[0]
    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])

    # Content of function_graphs is large and, for instance, device can change,
    # so only check that at least one function graph was collected. (hasattr()
    # on a proto field is always True and would check nothing.)
    self.assertNotEmpty(actual_run_metadata.function_graphs)

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInFunction(self):

    @def_function.function
    def f():
      summary_ops.trace_on(graph=True, profiler=False)
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot enable trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_on(graph=True, profiler=False)
      self.assertRegex(
          str(mock_log.call_args), 'Must enable trace in eager mode.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceWithoutTrace(self):
    with self.assertRaisesRegex(ValueError, 'Must enable trace before export.'):
      summary_ops.trace_export(name='foo', step=1)

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInFunction(self):
    summary_ops.trace_on(graph=True, profiler=False)

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      summary_ops.trace_export(name='foo', step=1)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot export trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_export(name='foo', step=1)
      self.assertRegex(
          str(mock_log.call_args),
          'Can only export trace while executing eagerly.')

  @test_util.run_v2_only
  def testTrace_usesDefaultStep(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    try:
      summary_ops.set_step(42)
      event = self.run_trace(f, step=None)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace_withProfiler(self):
    # Smoke test: tracing with the profiler enabled and exporting to a
    # profiler output directory must not raise.

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=True)
    profiler_outdir = self.get_temp_dir()
    with writer.as_default():
      f()
      summary_ops.trace_export(
          name='foo', step=1, profiler_outdir=profiler_outdir)
    writer.close()

  @test_util.run_v2_only
  def testGraph_graph(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph)

    event = self.exec_summary_op(summary_op_fn)
    # event.graph_def is a proto bytes field and is never None, so parse it and
    # check that the written graph actually contains nodes.
    graph_def = graph_pb2.GraphDef.FromString(event.graph_def)
    self.assertNotEmpty(graph_def.node)

  @test_util.run_v2_only
  def testGraph_graphDef(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph.as_graph_def())

    event = self.exec_summary_op(summary_op_fn)
    # event.graph_def is a proto bytes field and is never None, so parse it and
    # check that the written graph actually contains nodes.
    graph_def = graph_pb2.GraphDef.FromString(event.graph_def)
    self.assertNotEmpty(graph_def.node)

  @test_util.run_v2_only
  def testGraph_invalidData(self):
    def summary_op_fn():
      summary_ops.graph('hello')

    with self.assertRaisesRegex(
        ValueError,
        r'\'graph_data\' is not tf.Graph or tf.compat.v1.GraphDef',
    ):
      self.exec_summary_op(summary_op_fn)

  @test_util.run_v2_only
  def testGraph_fromGraphMode(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    @def_function.function
    def g(graph):
      summary_ops.graph(graph)

    def summary_op_fn():
      graph_def = f.get_concrete_function().graph.as_graph_def(add_shapes=True)
      func_graph = constant_op.constant(graph_def.SerializeToString())
      g(func_graph)

    with self.assertRaisesRegex(
        ValueError,
        r'graph\(\) cannot be invoked inside a graph context.',
    ):
      self.exec_summary_op(summary_op_fn)
1570
1571
def events_from_file(filepath):
  """Parses every record of an event file into a tf.Event proto.

  Args:
    filepath: Path to the event file.

  Returns:
    A list of tf.Event protos, one per record, in the order they were read.
  """
  # Read every record up front, then decode them.
  records = list(tf_record.tf_record_iterator(filepath))
  return [event_pb2.Event.FromString(record) for record in records]
1588
1589
def events_from_logdir(logdir):
  """Returns all events in the single event file in logdir.

  Args:
    logdir: The directory in which the single event file is sought.

  Returns:
    A list of all tf.Event protos from the single event file.

  Raises:
    AssertionError: If logdir does not exist or does not contain exactly one
      file.
  """
  # Raise explicitly instead of using `assert`, so the documented
  # AssertionError is still raised when Python runs with -O.
  if not gfile.Exists(logdir):
    raise AssertionError('logdir does not exist: %s' % (logdir,))
  files = gfile.ListDirectory(logdir)
  if len(files) != 1:
    raise AssertionError(
        'Found not exactly one file in logdir: %s' % (files,))
  return events_from_file(os.path.join(logdir, files[0]))
1606
1607
def events_from_multifile_logdir(logdir):
  """Returns map of filename to events for all `tfevents` files in the logdir.

  Only directory entries whose name contains the substring 'tfevents' are
  loaded; everything else is ignored.

  Args:
    logdir: The directory from which to load events.

  Returns:
    A dict mapping from relative filenames to lists of tf.Event protos. The
    dict is empty when no matching file is present.

  Raises:
    AssertionError: If logdir does not exist.
  """
  # Raise explicitly instead of using `assert`, so the documented
  # AssertionError is still raised when Python runs with -O.
  if not gfile.Exists(logdir):
    raise AssertionError('logdir does not exist: %s' % (logdir,))
  return {
      name: events_from_file(os.path.join(logdir, name))
      for name in gfile.ListDirectory(logdir)
      if 'tfevents' in name
  }
1623
1624
def to_numpy(summary_value):
  """Converts the `tensor` field of a summary value into a numpy ndarray.

  Args:
    summary_value: An object with a `tensor` attribute holding a TensorProto
      (presumably a `summary_pb2.Summary.Value`; confirm at call sites).

  Returns:
    The ndarray produced by `tensor_util.MakeNdarray`.
  """
  tensor_proto = summary_value.tensor
  return tensor_util.MakeNdarray(tensor_proto)
1627
1628
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand off to the
  # TensorFlow test runner.
  test.main()
1631