• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for V2 summary ops from summary_ops_v2."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import unittest
23
24import six
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import node_def_pb2
28from tensorflow.core.framework import step_stats_pb2
29from tensorflow.core.framework import summary_pb2
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.util import event_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.framework import test_util
41from tensorflow.python.lib.io import tf_record
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import summary_ops_v2 as summary_ops
44from tensorflow.python.ops import variables
45from tensorflow.python.platform import gfile
46from tensorflow.python.platform import test
47from tensorflow.python.platform import tf_logging as logging
48
49
class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
  """Tests for summary_ops_v2 write ops, default-step handling and scopes.

  Several tests install shared summary state (a default step via set_step(),
  or a default writer via set_as_default()). They restore the default state
  in a `finally` clause so that later tests are not affected.
  """

  def testWrite(self):
    """Eager write() returns True and emits one event with tag/step/value."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write('tag', 42, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    # events[0] is the file_version event; events[1] is our summary.
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_fromFunction(self):
    """write() inside a tf.function using a writer created outside of it."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write('tag', 42, step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_metadata(self):
    """metadata may be a SummaryMetadata proto, its bytes, or a string Tensor."""
    logdir = self.get_temp_dir()
    metadata = summary_pb2.SummaryMetadata()
    metadata.plugin_data.plugin_name = 'foo'
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('obj', 0, 0, metadata=metadata)
        summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString())
        m = constant_op.constant(metadata.SerializeToString())
        summary_ops.write('string_tensor', 0, 0, metadata=m)
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(metadata, events[1].summary.value[0].metadata)
    self.assertEqual(metadata, events[2].summary.value[0].metadata)
    self.assertEqual(metadata, events[3].summary.value[0].metadata)

  def testWrite_name(self):
    """The name argument becomes the prefix of the returned op's name."""
    @def_function.function
    def f():
      output = summary_ops.write('tag', 42, step=12, name='anonymous')
      self.assertTrue(output.name.startswith('anonymous'))
    f()

  def testWrite_ndarray(self):
    """A nested Python list round-trips through the event file."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value))

  def testWrite_tensor(self):
    """An eager Tensor value round-trips through the event file."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      t = constant_op.constant([[1, 2], [3, 4]])
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', t, step=12)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_tensor_fromFunction(self):
    """A Tensor passed as a tf.function argument round-trips."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f(t):
        with writer.as_default():
          summary_ops.write('tag', t, step=12)
      t = constant_op.constant([[1, 2], [3, 4]])
      f(t)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_stringTensor(self):
    """A list of bytes round-trips as a string tensor."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [b'foo', b'bar'], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([b'foo', b'bar'], to_numpy(value))

  @test_util.run_gpu_only
  def testWrite_gpuDeviceContext(self):
    """write() works when its value and step are created under a GPU scope."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer(logdir).as_default():
        with ops.device('/GPU:0'):
          value = constant_op.constant(42.0)
          step = constant_op.constant(12, dtype=dtypes.int64)
          summary_ops.write('tag', value, step=step).numpy()
    empty_metadata = summary_pb2.SummaryMetadata()
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertEqual(42, to_numpy(events[1].summary.value[0]))
    self.assertEqual(empty_metadata, events[1].summary.value[0].metadata)

  @test_util.also_run_as_tf_function
  def testWrite_noDefaultWriter(self):
    """Without a default writer, write() returns False."""
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42, step=0))

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfAlsoNoDefaultWriter(self):
    """A missing step is not an error when there is no default writer."""
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42))

  @test_util.also_run_as_tf_function
  def testWrite_noStep(self):
    """With a default writer but no step anywhere, write() raises."""
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer(logdir).as_default():
      with self.assertRaisesRegex(ValueError, 'No step set'):
        summary_ops.write('tag', 42)

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfNotRecordingSummaries(self):
    """A missing step is tolerated while recording is disabled."""
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer(logdir).as_default():
      with summary_ops.record_if(False):
        # Use assertAllEqual instead of assertFalse since it works in a defun.
        self.assertAllEqual(False, summary_ops.write('tag', 42))

  def testWrite_usingDefaultStep(self):
    """write() without step= uses set_step(), whether constant or Variable."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        with summary_ops.create_file_writer(logdir).as_default():
          summary_ops.set_step(1)
          summary_ops.write('tag', 1.0)
          summary_ops.set_step(2)
          summary_ops.write('tag', 1.0)
          mystep = variables.Variable(10, dtype=dtypes.int64)
          summary_ops.set_step(mystep)
          summary_ops.write('tag', 1.0)
          # A Variable step is read at write time, so this sees 11.
          mystep.assign_add(1)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertEqual(5, len(events))
      self.assertEqual(1, events[1].step)
      self.assertEqual(2, events[2].step)
      self.assertEqual(10, events[3].step)
      self.assertEqual(11, events[4].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromFunction(self):
    """A constant default step is captured when a tf.function is traced."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        summary_ops.set_step(1)
        f()
        summary_ops.set_step(2)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the function was first traced.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromFunction(self):
    """A default step Variable is re-read on every tf.function call."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        f()
        mystep.assign_add(1)
        f()
        mystep.assign(10)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromLegacyGraph(self):
    """In graph mode a constant default step is fixed at op-build time."""
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer(logdir)
        summary_ops.set_step(1)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        summary_ops.set_step(2)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(write_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the graph was constructed.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromLegacyGraph(self):
    """In graph mode a default step Variable is read each time the op runs."""
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer(logdir)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        first_assign_op = mystep.assign_add(1)
        second_assign_op = mystep.assign(10)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(mystep.initializer)
          sess.run(write_op)
          sess.run(first_assign_op)
          sess.run(write_op)
          sess.run(second_assign_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromAsDefault(self):
    """as_default(step=...) scopes a constant step; as_default() keeps it."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        with writer.as_default(step=1):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=2):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            summary_ops.set_step(3)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 1, 2, 1, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromAsDefault(self):
    """as_default(step=Variable) reads the Variable at every write."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        with writer.as_default(step=mystep):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            mystep.assign(2)
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=3):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            mystep.assign(4)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3, 2, 4], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromSetAsDefault(self):
    """set_as_default(step=...) sets a constant step; omitting step keeps it."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        writer.set_as_default(step=1)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default()
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 2], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests. set_as_default() is not a
      # context manager, so the default writer must be cleared explicitly.
      summary_ops.set_step(None)
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testWrite_usingDefaultStepVariable_fromSetAsDefault(self):
    """set_as_default(step=Variable) reads the Variable at every write."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        writer.set_as_default(step=mystep)
        summary_ops.write('tag', 1.0)
        mystep.assign(2)
        summary_ops.write('tag', 1.0)
        # Installing a constant step replaces the Variable as the default.
        writer.set_as_default(step=3)
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests. set_as_default() is not a
      # context manager, so the default writer must be cleared explicitly.
      summary_ops.set_step(None)
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testWrite_recordIf_constant(self):
    """record_if(bool) decides whether write() emits an event."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        self.assertTrue(summary_ops.write('default', 1, step=0))
        with summary_ops.record_if(True):
          self.assertTrue(summary_ops.write('set_on', 1, step=0))
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('set_off', 1, step=0))
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_constant_fromFunction(self):
    """record_if(bool) gating also applies inside a tf.function."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          # Use assertAllEqual instead of assertTrue since it works in a defun.
          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
          with summary_ops.record_if(True):
            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
          with summary_ops.record_if(False):
            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
      f()
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_callable(self):
    """record_if(callable) evaluates the predicate on each write()."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      step = variables.Variable(-1, dtype=dtypes.int64)
      def record_fn():
        # Advances the step as a side effect, so only even steps record.
        step.assign_add(1)
        return int(step % 2) == 0
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(record_fn):
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_callable_fromFunction(self):
    """A tf.function predicate is re-evaluated on each write, across calls."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      step = variables.Variable(-1, dtype=dtypes.int64)
      @def_function.function
      def record_fn():
        step.assign_add(1)
        return math_ops.equal(step % 2, 0)
      @def_function.function
      def f():
        with writer.as_default():
          with summary_ops.record_if(record_fn):
            return [
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step)]
      self.assertAllEqual(f(), [True, False, True])
      self.assertAllEqual(f(), [False, True, False])
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_tensorInput_fromFunction(self):
    """record_if accepts a boolean Tensor computed from a function argument."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
      def f(step):
        with writer.as_default():
          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
            return summary_ops.write('tag', 1, step=step)
      self.assertTrue(f(0))
      self.assertFalse(f(1))
      self.assertTrue(f(2))
      self.assertFalse(f(3))
      self.assertTrue(f(4))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWriteRawPb(self):
    """write_raw_pb() writes the given serialized Summary as the event body."""
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_fromFunction(self):
    """write_raw_pb() inside a tf.function."""
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_multipleValues(self):
    """Several serialized Summaries are merged, in order, into one event."""
    logdir = self.get_temp_dir()
    pb1 = summary_pb2.Summary()
    pb1.value.add().simple_value = 1.0
    pb1.value.add().simple_value = 2.0
    pb2 = summary_pb2.Summary()
    pb2.value.add().simple_value = 3.0
    pb3 = summary_pb2.Summary()
    pb3.value.add().simple_value = 4.0
    pb3.value.add().simple_value = 5.0
    pb3.value.add().simple_value = 6.0
    pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)]
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pbs, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    expected_pb = summary_pb2.Summary()
    for i in range(6):
      expected_pb.value.add().simple_value = i + 1.0
    self.assertProtoEquals(expected_pb, events[1].summary)

  def testWriteRawPb_invalidValue(self):
    """Bytes that do not parse as a Summary proto raise DataLossError."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(
            errors.DataLossError,
            'Bad tf.compat.v1.Summary binary proto tensor string'):
          summary_ops.write_raw_pb('notaproto', step=12)

  @test_util.also_run_as_tf_function
  def testGetSetStep(self):
    """get_step() returns the value passed to set_step() (int or Tensor)."""
    try:
      self.assertIsNone(summary_ops.get_step())
      summary_ops.set_step(1)
      # Use assertAllEqual instead of assertEqual since it works in a defun.
      self.assertAllEqual(1, summary_ops.get_step())
      summary_ops.set_step(constant_op.constant(2))
      self.assertAllEqual(2, summary_ops.get_step())
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testGetSetStep_variable(self):
    """set_step(Variable) stores the Variable itself, not a snapshot."""
    with context.eager_mode():
      try:
        mystep = variables.Variable(0)
        summary_ops.set_step(mystep)
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        mystep.assign_add(1)
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        summary_ops.get_step().assign_add(1)
        self.assertAllEqual(2, summary_ops.get_step().read_value())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  def testGetSetStep_variable_fromFunction(self):
    """set_step()/get_step() on a Variable also work inside tf.functions."""
    with context.eager_mode():
      try:
        @def_function.function
        def set_step(step):
          summary_ops.set_step(step)
          return summary_ops.get_step()
        @def_function.function
        def get_and_increment():
          summary_ops.get_step().assign_add(1)
          return summary_ops.get_step()
        mystep = variables.Variable(0)
        self.assertAllEqual(0, set_step(mystep))
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        self.assertAllEqual(1, get_and_increment())
        self.assertAllEqual(2, get_and_increment())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(3, get_and_increment())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  @test_util.also_run_as_tf_function
  def testSummaryScope(self):
    """summary_scope yields (tag, scope) nested under enclosing name scopes."""
    with summary_ops.summary_scope('foo') as (tag, scope):
      self.assertEqual('foo', tag)
      self.assertEqual('foo/', scope)
      with summary_ops.summary_scope('bar') as (tag, scope):
        self.assertEqual('foo/bar', tag)
        self.assertEqual('foo/bar/', scope)
      with summary_ops.summary_scope('with/slash') as (tag, scope):
        self.assertEqual('foo/with/slash', tag)
        self.assertEqual('foo/with/slash/', scope)
      # name_scope(None) clears the enclosing scope.
      with ops.name_scope(None, skip_on_eager=False):
        with summary_ops.summary_scope('unnested') as (tag, scope):
          self.assertEqual('unnested', tag)
          self.assertEqual('unnested/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_defaultName(self):
    """A None name falls back to 'summary' or to the given default name."""
    with summary_ops.summary_scope(None) as (tag, scope):
      self.assertEqual('summary', tag)
      self.assertEqual('summary/', scope)
    with summary_ops.summary_scope(None, 'backup') as (tag, scope):
      self.assertEqual('backup', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_handlesCharactersIllegalForScope(self):
    """Illegal scope characters are kept in the tag but dropped from scope."""
    with summary_ops.summary_scope('f?o?o') as (tag, scope):
      self.assertEqual('f?o?o', tag)
      self.assertEqual('foo/', scope)
    # If all characters aren't legal for a scope name, use default name.
    with summary_ops.summary_scope('???', 'backup') as (tag, scope):
      self.assertEqual('???', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_nameNotUniquifiedForTag(self):
    """Tags stay exact even when the matching op/scope name is already used."""
    constant_op.constant(0, name='foo')
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with ops.name_scope('with', skip_on_eager=False):
      constant_op.constant(0, name='slash')
    with summary_ops.summary_scope('with/slash') as (tag, _):
      self.assertEqual('with/slash', tag)

  def testAllV2SummaryOps(self):
    """all_v2_summary_ops() returns ops built while a default writer was set."""
    logdir = self.get_temp_dir()
    def define_ops():
      result = []
      # TF 2.0 summary ops
      result.append(summary_ops.write('write', 1, step=0))
      result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb'))
      # TF 1.x tf.contrib.summary ops
      result.append(summary_ops.generic('tensor', 1, step=1))
      result.append(summary_ops.scalar('scalar', 2.0, step=1))
      result.append(summary_ops.histogram('histogram', [1.0], step=1))
      result.append(summary_ops.image('image', [[[[1.0]]]], step=1))
      result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1))
      return result
    with context.graph_mode():
      ops_without_writer = define_ops()
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          ops_recording_on = define_ops()
        with summary_ops.record_if(False):
          ops_recording_off = define_ops()
      # We should be collecting all ops defined with a default writer present,
      # regardless of whether recording was set on or off, but not those defined
      # without a writer at all.
      del ops_without_writer
      expected_ops = ops_recording_on + ops_recording_off
      self.assertCountEqual(expected_ops, summary_ops.all_v2_summary_ops())
683
684
685class SummaryWriterTest(test_util.TensorFlowTestCase):
686
687  def testCreate_withInitAndClose(self):
688    logdir = self.get_temp_dir()
689    with context.eager_mode():
690      writer = summary_ops.create_file_writer_v2(
691          logdir, max_queue=1000, flush_millis=1000000)
692      get_total = lambda: len(events_from_logdir(logdir))
693      self.assertEqual(1, get_total())  # file_version Event
694      # Calling init() again while writer is open has no effect
695      writer.init()
696      self.assertEqual(1, get_total())
697      with writer.as_default():
698        summary_ops.write('tag', 1, step=0)
699        self.assertEqual(1, get_total())
700        # Calling .close() should do an implicit flush
701        writer.close()
702        self.assertEqual(2, get_total())
703
704  def testCreate_fromFunction(self):
705    logdir = self.get_temp_dir()
706    @def_function.function
707    def f():
708      # Returned SummaryWriter must be stored in a non-local variable so it
709      # lives throughout the function execution.
710      if not hasattr(f, 'writer'):
711        f.writer = summary_ops.create_file_writer_v2(logdir)
712    with context.eager_mode():
713      f()
714    event_files = gfile.Glob(os.path.join(logdir, '*'))
715    self.assertEqual(1, len(event_files))
716
717  def testCreate_graphTensorArgument_raisesError(self):
718    logdir = self.get_temp_dir()
719    with context.graph_mode():
720      logdir_tensor = constant_op.constant(logdir)
721    with context.eager_mode():
722      with self.assertRaisesRegex(
723          ValueError, 'Invalid graph Tensor argument.*logdir'):
724        summary_ops.create_file_writer_v2(logdir_tensor)
725    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))
726
727  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
728    logdir = self.get_temp_dir()
729    @def_function.function
730    def f():
731      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
732    with context.eager_mode():
733      with self.assertRaisesRegex(
734          ValueError, 'Invalid graph Tensor argument.*logdir'):
735        f()
736    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))
737
738  def testCreate_fromFunction_unpersistedResource_raisesError(self):
739    logdir = self.get_temp_dir()
740    @def_function.function
741    def f():
742      with summary_ops.create_file_writer_v2(logdir).as_default():
743        pass  # Calling .as_default() is enough to indicate use.
744    with context.eager_mode():
745      # TODO(nickfelt): change this to a better error
746      with self.assertRaisesRegex(
747          errors.NotFoundError, 'Resource.*does not exist'):
748        f()
749    # Even though we didn't use it, an event file will have been created.
750    self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*'))))
751
752  def testCreate_immediateSetAsDefault_retainsReference(self):
753    logdir = self.get_temp_dir()
754    try:
755      with context.eager_mode():
756        summary_ops.create_file_writer_v2(logdir).set_as_default()
757        summary_ops.flush()
758    finally:
759      # Ensure we clean up no matter how the test executes.
760      summary_ops._summary_state.writer = None  # pylint: disable=protected-access
761
762  def testCreate_immediateAsDefault_retainsReference(self):
763    logdir = self.get_temp_dir()
764    with context.eager_mode():
765      with summary_ops.create_file_writer_v2(logdir).as_default():
766        summary_ops.flush()
767
768  def testNoSharing(self):
769    # Two writers with the same logdir should not share state.
770    logdir = self.get_temp_dir()
771    with context.eager_mode():
772      writer1 = summary_ops.create_file_writer_v2(logdir)
773      with writer1.as_default():
774        summary_ops.write('tag', 1, step=1)
775      event_files = gfile.Glob(os.path.join(logdir, '*'))
776      self.assertEqual(1, len(event_files))
777      file1 = event_files[0]
778
779      writer2 = summary_ops.create_file_writer_v2(logdir)
780      with writer2.as_default():
781        summary_ops.write('tag', 1, step=2)
782      event_files = gfile.Glob(os.path.join(logdir, '*'))
783      self.assertEqual(2, len(event_files))
784      event_files.remove(file1)
785      file2 = event_files[0]
786
787      # Extra writes to ensure interleaved usage works.
788      with writer1.as_default():
789        summary_ops.write('tag', 1, step=1)
790      with writer2.as_default():
791        summary_ops.write('tag', 1, step=2)
792
793    events = iter(events_from_file(file1))
794    self.assertEqual('brain.Event:2', next(events).file_version)
795    self.assertEqual(1, next(events).step)
796    self.assertEqual(1, next(events).step)
797    self.assertRaises(StopIteration, lambda: next(events))
798    events = iter(events_from_file(file2))
799    self.assertEqual('brain.Event:2', next(events).file_version)
800    self.assertEqual(2, next(events).step)
801    self.assertEqual(2, next(events).step)
802    self.assertRaises(StopIteration, lambda: next(events))
803
804  def testNoSharing_fromFunction(self):
805    logdir = self.get_temp_dir()
806    @def_function.function
807    def f1():
808      if not hasattr(f1, 'writer'):
809        f1.writer = summary_ops.create_file_writer_v2(logdir)
810      with f1.writer.as_default():
811        summary_ops.write('tag', 1, step=1)
812    @def_function.function
813    def f2():
814      if not hasattr(f2, 'writer'):
815        f2.writer = summary_ops.create_file_writer_v2(logdir)
816      with f2.writer.as_default():
817        summary_ops.write('tag', 1, step=2)
818    with context.eager_mode():
819      f1()
820      event_files = gfile.Glob(os.path.join(logdir, '*'))
821      self.assertEqual(1, len(event_files))
822      file1 = event_files[0]
823
824      f2()
825      event_files = gfile.Glob(os.path.join(logdir, '*'))
826      self.assertEqual(2, len(event_files))
827      event_files.remove(file1)
828      file2 = event_files[0]
829
830      # Extra writes to ensure interleaved usage works.
831      f1()
832      f2()
833
834    events = iter(events_from_file(file1))
835    self.assertEqual('brain.Event:2', next(events).file_version)
836    self.assertEqual(1, next(events).step)
837    self.assertEqual(1, next(events).step)
838    self.assertRaises(StopIteration, lambda: next(events))
839    events = iter(events_from_file(file2))
840    self.assertEqual('brain.Event:2', next(events).file_version)
841    self.assertEqual(2, next(events).step)
842    self.assertEqual(2, next(events).step)
843    self.assertRaises(StopIteration, lambda: next(events))
844
845  def testMaxQueue(self):
846    logdir = self.get_temp_dir()
847    with context.eager_mode():
848      with summary_ops.create_file_writer_v2(
849          logdir, max_queue=1, flush_millis=999999).as_default():
850        get_total = lambda: len(events_from_logdir(logdir))
851        # Note: First tf.compat.v1.Event is always file_version.
852        self.assertEqual(1, get_total())
853        summary_ops.write('tag', 1, step=0)
854        self.assertEqual(1, get_total())
855        # Should flush after second summary since max_queue = 1
856        summary_ops.write('tag', 1, step=0)
857        self.assertEqual(3, get_total())
858
859  def testWriterFlush(self):
860    logdir = self.get_temp_dir()
861    get_total = lambda: len(events_from_logdir(logdir))
862    with context.eager_mode():
863      writer = summary_ops.create_file_writer_v2(
864          logdir, max_queue=1000, flush_millis=1000000)
865      self.assertEqual(1, get_total())  # file_version Event
866      with writer.as_default():
867        summary_ops.write('tag', 1, step=0)
868        self.assertEqual(1, get_total())
869        writer.flush()
870        self.assertEqual(2, get_total())
871        summary_ops.write('tag', 1, step=0)
872        self.assertEqual(2, get_total())
873      # Exiting the "as_default()" should do an implicit flush
874      self.assertEqual(3, get_total())
875
876  def testFlushFunction(self):
877    logdir = self.get_temp_dir()
878    with context.eager_mode():
879      writer = summary_ops.create_file_writer_v2(
880          logdir, max_queue=999999, flush_millis=999999)
881      with writer.as_default():
882        get_total = lambda: len(events_from_logdir(logdir))
883        # Note: First tf.compat.v1.Event is always file_version.
884        self.assertEqual(1, get_total())
885        summary_ops.write('tag', 1, step=0)
886        summary_ops.write('tag', 1, step=0)
887        self.assertEqual(1, get_total())
888        summary_ops.flush()
889        self.assertEqual(3, get_total())
890        # Test "writer" parameter
891        summary_ops.write('tag', 1, step=0)
892        self.assertEqual(3, get_total())
893        summary_ops.flush(writer=writer)
894        self.assertEqual(4, get_total())
895        summary_ops.write('tag', 1, step=0)
896        self.assertEqual(4, get_total())
897        summary_ops.flush(writer=writer._resource)  # pylint:disable=protected-access
898        self.assertEqual(5, get_total())
899
  @test_util.assert_no_new_tensors
  def testNoMemoryLeak_graphMode(self):
    # Creates a writer inside a throwaway graph. Per its name, the decorator
    # verifies that no new tensors outlive the test body.
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      summary_ops.create_file_writer_v2(logdir)
905
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNoMemoryLeak_eagerMode(self):
    # No explicit context.eager_mode() here. Per its name, the decorator runs
    # the body eagerly and checks that no new Python objects survive it.
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer_v2(logdir).as_default():
      summary_ops.write('tag', 1, step=0)
911
912  def testClose_preventsLaterUse(self):
913    logdir = self.get_temp_dir()
914    with context.eager_mode():
915      writer = summary_ops.create_file_writer_v2(logdir)
916      writer.close()
917      writer.close()  # redundant close() is a no-op
918      writer.flush()  # redundant flush() is a no-op
919      with self.assertRaisesRegex(RuntimeError, 'already closed'):
920        writer.init()
921      with self.assertRaisesRegex(RuntimeError, 'already closed'):
922        with writer.as_default():
923          self.fail('should not get here')
924      with self.assertRaisesRegex(RuntimeError, 'already closed'):
925        writer.set_as_default()
926
927  def testClose_closesOpenFile(self):
928    try:
929      import psutil  # pylint: disable=g-import-not-at-top
930    except ImportError:
931      raise unittest.SkipTest('test requires psutil')
932    proc = psutil.Process()
933    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
934    logdir = self.get_temp_dir()
935    with context.eager_mode():
936      writer = summary_ops.create_file_writer_v2(logdir)
937      files = gfile.Glob(os.path.join(logdir, '*'))
938      self.assertEqual(1, len(files))
939      eventfile = files[0]
940      self.assertIn(eventfile, get_open_filenames())
941      writer.close()
942      self.assertNotIn(eventfile, get_open_filenames())
943
944  def testDereference_closesOpenFile(self):
945    try:
946      import psutil  # pylint: disable=g-import-not-at-top
947    except ImportError:
948      raise unittest.SkipTest('test requires psutil')
949    proc = psutil.Process()
950    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
951    logdir = self.get_temp_dir()
952    with context.eager_mode():
953      writer = summary_ops.create_file_writer_v2(logdir)
954      files = gfile.Glob(os.path.join(logdir, '*'))
955      self.assertEqual(1, len(files))
956      eventfile = files[0]
957      self.assertIn(eventfile, get_open_filenames())
958      del writer
959      self.assertNotIn(eventfile, get_open_filenames())
960
961
class SummaryOpsTest(test_util.TensorFlowTestCase):
  """Tests for the run-metadata, trace and graph summary ops."""

  def tearDown(self):
    try:
      # Tracing is toggled through module-level functions, and some tests below
      # call trace_on() without a matching export, so always switch it off.
      summary_ops.trace_off()
    finally:
      # Chain to the base class so whatever tear-down it performs still runs.
      super(SummaryOpsTest, self).tearDown()

  def exec_summary_op(self, summary_op_fn):
    """Runs `summary_op_fn` under a fresh default writer and returns an event.

    The helper asserts that it is executing eagerly.

    Args:
      summary_op_fn: Zero-argument callable that emits summaries.

    Returns:
      The event at index 1 of the log directory's single event file; index 0
      is the file_version event that starts every event file.

    Raises:
      Whatever `summary_op_fn` raises; the writer is closed first.
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer(logdir)
    try:
      with writer.as_default():
        summary_op_fn()
    finally:
      # Close even on failure: several tests expect summary_op_fn to raise,
      # and the event file must not be left open.
      writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata(self, *args, **kwargs):
    """Writes one `summary_ops.run_metadata` summary; returns its event."""
    return self.exec_summary_op(
        lambda: summary_ops.run_metadata(*args, **kwargs))

  def run_metadata_graphs(self, *args, **kwargs):
    """Writes one `summary_ops.run_metadata_graphs` summary; returns its event."""
    return self.exec_summary_op(
        lambda: summary_ops.run_metadata_graphs(*args, **kwargs))

  def create_run_metadata(self):
    """Builds a small RunMetadata with one step-stats node and one graph."""
    step_stats = step_stats_pb2.StepStats(dev_stats=[
        step_stats_pb2.DeviceStepStats(
            device='cpu:0',
            node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')])
    ])
    return config_pb2.RunMetadata(
        function_graphs=[
            config_pb2.RunMetadata.FunctionGraphs(
                pre_optimization_graph=graph_pb2.GraphDef(
                    node=[node_def_pb2.NodeDef(name='foo')]))
        ],
        step_stats=step_stats)

  def run_trace(self, f, step=1):
    """Traces a call to `f` with graph collection and exports it as 'foo'.

    Args:
      f: Zero-argument callable to run while tracing is enabled.
      step: Step passed to `trace_export`; None uses the default step.

    Returns:
      The event at index 1 of the single event file (index 0 is file_version).
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer(logdir)
    try:
      summary_ops.trace_on(graph=True, profiler=False)
      with writer.as_default():
        f()
        summary_ops.trace_export(name='foo', step=step)
    finally:
      writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  @test_util.run_v2_only
  def testRunMetadata_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadata_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadata_wholeRunMetadata(self):
    expected_run_metadata = """
      step_stats {
        dev_stats {
          device: "cpu:0"
          node_stats {
            node_name: "hello"
          }
        }
      }
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadata_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadataGraph_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata_graph"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_runMetadataFragment(self):
    expected_run_metadata = """
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()

    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata_graphs(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    event = self.run_trace(f)

    first_val = event.summary.value[0]
    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])

    # Content of function_graphs is large and, for instance, device can change,
    # so only check that the traced call recorded at least one function graph.
    # (hasattr() on a proto field is always true and checked nothing.)
    self.assertNotEmpty(actual_run_metadata.function_graphs)

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInFunction(self):

    @def_function.function
    def f():
      summary_ops.trace_on(graph=True, profiler=False)
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot enable trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_on(graph=True, profiler=False)
      self.assertRegex(
          str(mock_log.call_args), 'Must enable trace in eager mode.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceWithoutTrace(self):
    with self.assertRaisesRegex(ValueError, 'Must enable trace before export.'):
      summary_ops.trace_export(name='foo', step=1)

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInFunction(self):
    summary_ops.trace_on(graph=True, profiler=False)

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      summary_ops.trace_export(name='foo', step=1)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot export trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_export(name='foo', step=1)
      self.assertRegex(
          str(mock_log.call_args),
          'Can only export trace while executing eagerly.')

  @test_util.run_v2_only
  def testTrace_usesDefaultStep(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    try:
      summary_ops.set_step(42)
      event = self.run_trace(f, step=None)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace_withProfiler(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer(logdir)
    try:
      summary_ops.trace_on(graph=True, profiler=True)
      profiler_outdir = self.get_temp_dir()
      with writer.as_default():
        f()
        summary_ops.trace_export(
            name='foo', step=1, profiler_outdir=profiler_outdir)
    finally:
      # Release the event file even if tracing or export fails.
      writer.close()

  @test_util.run_v2_only
  def testGraph_graph(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph)

    event = self.exec_summary_op(summary_op_fn)
    # graph_def is a bytes field and is never None; parse it and check that
    # the written graph actually contains nodes.
    graph_def = graph_pb2.GraphDef.FromString(event.graph_def)
    self.assertNotEmpty(graph_def.node)

  @test_util.run_v2_only
  def testGraph_graphDef(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph.as_graph_def())

    event = self.exec_summary_op(summary_op_fn)
    # graph_def is a bytes field and is never None; parse it and check that
    # the written graph actually contains nodes.
    graph_def = graph_pb2.GraphDef.FromString(event.graph_def)
    self.assertNotEmpty(graph_def.node)

  @test_util.run_v2_only
  def testGraph_invalidData(self):
    def summary_op_fn():
      summary_ops.graph('hello')

    with self.assertRaisesRegex(
        ValueError,
        r'\'graph_data\' is not tf.Graph or tf.compat.v1.GraphDef',
    ):
      self.exec_summary_op(summary_op_fn)

  @test_util.run_v2_only
  def testGraph_fromGraphMode(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    @def_function.function
    def g(graph):
      summary_ops.graph(graph)

    def summary_op_fn():
      graph_def = f.get_concrete_function().graph.as_graph_def(add_shapes=True)
      func_graph = constant_op.constant(graph_def.SerializeToString())
      g(func_graph)

    with self.assertRaisesRegex(
        ValueError,
        r'graph\(\) cannot be invoked inside a graph context.',
    ):
      self.exec_summary_op(summary_op_fn)
1311
1312
def events_from_file(filepath):
  """Reads every record of an event file and parses it as an Event proto.

  Args:
    filepath: Path to a TFRecord file whose records are serialized `Event`s.

  Returns:
    A list of `event_pb2.Event` protos, in the order they appear in the file.
  """
  return [
      event_pb2.Event.FromString(serialized)
      for serialized in tf_record.tf_record_iterator(filepath)
  ]
1329
1330
def events_from_logdir(logdir):
  """Returns all events in the single eventfile in logdir.

  Args:
    logdir: The directory in which the single event file is sought.

  Returns:
    A list of all tf.Event protos from the single event file.

  Raises:
    AssertionError: If logdir does not exist or does not contain exactly one
      directory entry.
  """
  # Explicit raises instead of `assert`: assert statements are stripped under
  # `python -O`, which would silently skip the documented AssertionError.
  if not gfile.Exists(logdir):
    raise AssertionError('logdir does not exist: %s' % logdir)
  files = gfile.ListDirectory(logdir)
  if len(files) != 1:
    raise AssertionError('Found not exactly one file in logdir: %s' % files)
  return events_from_file(os.path.join(logdir, files[0]))
1347
1348
def to_numpy(summary_value):
  """Converts the `tensor` field of a summary value into a numpy array.

  Args:
    summary_value: A proto with a `tensor` field holding a TensorProto, such
      as an entry of `Event.summary.value`.

  Returns:
    The result of `tensor_util.MakeNdarray` on that TensorProto.
  """
  tensor_proto = summary_value.tensor
  return tensor_util.MakeNdarray(tensor_proto)
1351
1352
if __name__ == '__main__':
  # Run every test case in this module. `test` is imported outside this view,
  # presumably as tensorflow.python.platform.test; confirm against the imports.
  test.main()
1355