1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for V2 summary ops from summary_ops_v2.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import unittest 23 24import six 25 26from tensorflow.core.framework import graph_pb2 27from tensorflow.core.framework import node_def_pb2 28from tensorflow.core.framework import step_stats_pb2 29from tensorflow.core.framework import summary_pb2 30from tensorflow.core.protobuf import config_pb2 31from tensorflow.core.util import event_pb2 32from tensorflow.python.eager import context 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.framework import test_util 41from tensorflow.python.lib.io import tf_record 42from tensorflow.python.ops import math_ops 43from tensorflow.python.ops import summary_ops_v2 as summary_ops 44from tensorflow.python.ops import variables 45from tensorflow.python.platform import gfile 46from tensorflow.python.platform import test 47from tensorflow.python.platform import tf_logging as logging 48 49 50class SummaryOpsCoreTest(test_util.TensorFlowTestCase): 51 52 def testWrite(self): 53 logdir = self.get_temp_dir() 54 with context.eager_mode(): 55 with summary_ops.create_file_writer_v2(logdir).as_default(): 56 output = summary_ops.write('tag', 42, step=12) 57 self.assertTrue(output.numpy()) 58 events = events_from_logdir(logdir) 59 self.assertEqual(2, len(events)) 60 self.assertEqual(12, events[1].step) 61 value = events[1].summary.value[0] 62 self.assertEqual('tag', value.tag) 63 self.assertEqual(42, to_numpy(value)) 64 65 def testWrite_fromFunction(self): 66 logdir = self.get_temp_dir() 67 with context.eager_mode(): 68 writer = summary_ops.create_file_writer_v2(logdir) 69 @def_function.function 70 def f(): 71 with writer.as_default(): 72 return summary_ops.write('tag', 42, step=12) 73 output = f() 74 self.assertTrue(output.numpy()) 75 events = events_from_logdir(logdir) 76 self.assertEqual(2, len(events)) 77 self.assertEqual(12, events[1].step) 78 value = events[1].summary.value[0] 79 self.assertEqual('tag', value.tag) 80 self.assertEqual(42, to_numpy(value)) 81 82 def testWrite_metadata(self): 83 logdir = self.get_temp_dir() 84 metadata = summary_pb2.SummaryMetadata() 85 metadata.plugin_data.plugin_name = 'foo' 86 with context.eager_mode(): 87 with summary_ops.create_file_writer_v2(logdir).as_default(): 88 summary_ops.write('obj', 0, 0, metadata=metadata) 89 summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString()) 90 m = constant_op.constant(metadata.SerializeToString()) 91 summary_ops.write('string_tensor', 0, 0, metadata=m) 92 events = events_from_logdir(logdir) 93 self.assertEqual(4, len(events)) 94 self.assertEqual(metadata, events[1].summary.value[0].metadata) 95 self.assertEqual(metadata, events[2].summary.value[0].metadata) 96 self.assertEqual(metadata, events[3].summary.value[0].metadata) 97 98 def testWrite_name(self): 99 @def_function.function 100 def f(): 101 output = summary_ops.write('tag', 42, step=12, name='anonymous') 102 self.assertTrue(output.name.startswith('anonymous')) 103 f() 104 105 def testWrite_ndarray(self): 106 logdir = self.get_temp_dir() 107 with context.eager_mode(): 108 with summary_ops.create_file_writer_v2(logdir).as_default(): 109 summary_ops.write('tag', [[1, 2], [3, 4]], step=12) 110 events = events_from_logdir(logdir) 111 value = events[1].summary.value[0] 112 self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value)) 113 114 def testWrite_tensor(self): 115 logdir = self.get_temp_dir() 116 with context.eager_mode(): 117 t = constant_op.constant([[1, 2], [3, 4]]) 118 with summary_ops.create_file_writer_v2(logdir).as_default(): 119 summary_ops.write('tag', t, step=12) 120 expected = t.numpy() 121 events = events_from_logdir(logdir) 122 value = events[1].summary.value[0] 123 self.assertAllEqual(expected, to_numpy(value)) 124 125 def testWrite_tensor_fromFunction(self): 126 logdir = self.get_temp_dir() 127 with context.eager_mode(): 128 writer = summary_ops.create_file_writer_v2(logdir) 129 @def_function.function 130 def f(t): 131 with writer.as_default(): 132 summary_ops.write('tag', t, step=12) 133 t = constant_op.constant([[1, 2], [3, 4]]) 134 f(t) 135 expected = t.numpy() 136 events = events_from_logdir(logdir) 137 value = events[1].summary.value[0] 138 self.assertAllEqual(expected, to_numpy(value)) 139 140 def testWrite_stringTensor(self): 141 logdir = self.get_temp_dir() 142 with context.eager_mode(): 143 with summary_ops.create_file_writer_v2(logdir).as_default(): 144 summary_ops.write('tag', [b'foo', b'bar'], step=12) 145 events = events_from_logdir(logdir) 146 value = events[1].summary.value[0] 147 self.assertAllEqual([b'foo', b'bar'], to_numpy(value)) 148 149 @test_util.run_gpu_only 150 def testWrite_gpuDeviceContext(self): 151 logdir = self.get_temp_dir() 152 with context.eager_mode(): 153 with summary_ops.create_file_writer(logdir).as_default(): 154 with ops.device('/GPU:0'): 155 value = constant_op.constant(42.0) 156 step = constant_op.constant(12, dtype=dtypes.int64) 157 summary_ops.write('tag', value, step=step).numpy() 158 empty_metadata = summary_pb2.SummaryMetadata() 159 events = events_from_logdir(logdir) 160 self.assertEqual(2, len(events)) 161 self.assertEqual(12, events[1].step) 162 self.assertEqual(42, to_numpy(events[1].summary.value[0])) 163 self.assertEqual(empty_metadata, events[1].summary.value[0].metadata) 164 165 @test_util.also_run_as_tf_function 166 def testWrite_noDefaultWriter(self): 167 # Use assertAllEqual instead of assertFalse since it works in a defun. 168 self.assertAllEqual(False, summary_ops.write('tag', 42, step=0)) 169 170 @test_util.also_run_as_tf_function 171 def testWrite_noStep_okayIfAlsoNoDefaultWriter(self): 172 # Use assertAllEqual instead of assertFalse since it works in a defun. 173 self.assertAllEqual(False, summary_ops.write('tag', 42)) 174 175 @test_util.also_run_as_tf_function 176 def testWrite_noStep(self): 177 logdir = self.get_temp_dir() 178 with summary_ops.create_file_writer(logdir).as_default(): 179 with self.assertRaisesRegex(ValueError, 'No step set'): 180 summary_ops.write('tag', 42) 181 182 @test_util.also_run_as_tf_function 183 def testWrite_noStep_okayIfNotRecordingSummaries(self): 184 logdir = self.get_temp_dir() 185 with summary_ops.create_file_writer(logdir).as_default(): 186 with summary_ops.record_if(False): 187 # Use assertAllEqual instead of assertFalse since it works in a defun. 188 self.assertAllEqual(False, summary_ops.write('tag', 42)) 189 190 def testWrite_usingDefaultStep(self): 191 logdir = self.get_temp_dir() 192 try: 193 with context.eager_mode(): 194 with summary_ops.create_file_writer(logdir).as_default(): 195 summary_ops.set_step(1) 196 summary_ops.write('tag', 1.0) 197 summary_ops.set_step(2) 198 summary_ops.write('tag', 1.0) 199 mystep = variables.Variable(10, dtype=dtypes.int64) 200 summary_ops.set_step(mystep) 201 summary_ops.write('tag', 1.0) 202 mystep.assign_add(1) 203 summary_ops.write('tag', 1.0) 204 events = events_from_logdir(logdir) 205 self.assertEqual(5, len(events)) 206 self.assertEqual(1, events[1].step) 207 self.assertEqual(2, events[2].step) 208 self.assertEqual(10, events[3].step) 209 self.assertEqual(11, events[4].step) 210 finally: 211 # Reset to default state for other tests. 212 summary_ops.set_step(None) 213 214 def testWrite_usingDefaultStepConstant_fromFunction(self): 215 logdir = self.get_temp_dir() 216 try: 217 with context.eager_mode(): 218 writer = summary_ops.create_file_writer(logdir) 219 @def_function.function 220 def f(): 221 with writer.as_default(): 222 summary_ops.write('tag', 1.0) 223 summary_ops.set_step(1) 224 f() 225 summary_ops.set_step(2) 226 f() 227 events = events_from_logdir(logdir) 228 self.assertEqual(3, len(events)) 229 self.assertEqual(1, events[1].step) 230 # The step value will still be 1 because the value was captured at the 231 # time the function was first traced. 232 self.assertEqual(1, events[2].step) 233 finally: 234 # Reset to default state for other tests. 235 summary_ops.set_step(None) 236 237 def testWrite_usingDefaultStepVariable_fromFunction(self): 238 logdir = self.get_temp_dir() 239 try: 240 with context.eager_mode(): 241 writer = summary_ops.create_file_writer(logdir) 242 @def_function.function 243 def f(): 244 with writer.as_default(): 245 summary_ops.write('tag', 1.0) 246 mystep = variables.Variable(0, dtype=dtypes.int64) 247 summary_ops.set_step(mystep) 248 f() 249 mystep.assign_add(1) 250 f() 251 mystep.assign(10) 252 f() 253 events = events_from_logdir(logdir) 254 self.assertEqual(4, len(events)) 255 self.assertEqual(0, events[1].step) 256 self.assertEqual(1, events[2].step) 257 self.assertEqual(10, events[3].step) 258 finally: 259 # Reset to default state for other tests. 260 summary_ops.set_step(None) 261 262 def testWrite_usingDefaultStepConstant_fromLegacyGraph(self): 263 logdir = self.get_temp_dir() 264 try: 265 with context.graph_mode(): 266 writer = summary_ops.create_file_writer(logdir) 267 summary_ops.set_step(1) 268 with writer.as_default(): 269 write_op = summary_ops.write('tag', 1.0) 270 summary_ops.set_step(2) 271 with self.cached_session() as sess: 272 sess.run(writer.init()) 273 sess.run(write_op) 274 sess.run(write_op) 275 sess.run(writer.flush()) 276 events = events_from_logdir(logdir) 277 self.assertEqual(3, len(events)) 278 self.assertEqual(1, events[1].step) 279 # The step value will still be 1 because the value was captured at the 280 # time the graph was constructed. 281 self.assertEqual(1, events[2].step) 282 finally: 283 # Reset to default state for other tests. 284 summary_ops.set_step(None) 285 286 def testWrite_usingDefaultStepVariable_fromLegacyGraph(self): 287 logdir = self.get_temp_dir() 288 try: 289 with context.graph_mode(): 290 writer = summary_ops.create_file_writer(logdir) 291 mystep = variables.Variable(0, dtype=dtypes.int64) 292 summary_ops.set_step(mystep) 293 with writer.as_default(): 294 write_op = summary_ops.write('tag', 1.0) 295 first_assign_op = mystep.assign_add(1) 296 second_assign_op = mystep.assign(10) 297 with self.cached_session() as sess: 298 sess.run(writer.init()) 299 sess.run(mystep.initializer) 300 sess.run(write_op) 301 sess.run(first_assign_op) 302 sess.run(write_op) 303 sess.run(second_assign_op) 304 sess.run(write_op) 305 sess.run(writer.flush()) 306 events = events_from_logdir(logdir) 307 self.assertEqual(4, len(events)) 308 self.assertEqual(0, events[1].step) 309 self.assertEqual(1, events[2].step) 310 self.assertEqual(10, events[3].step) 311 finally: 312 # Reset to default state for other tests. 313 summary_ops.set_step(None) 314 315 def testWrite_usingDefaultStep_fromAsDefault(self): 316 logdir = self.get_temp_dir() 317 try: 318 with context.eager_mode(): 319 writer = summary_ops.create_file_writer(logdir) 320 with writer.as_default(step=1): 321 summary_ops.write('tag', 1.0) 322 with writer.as_default(): 323 summary_ops.write('tag', 1.0) 324 with writer.as_default(step=2): 325 summary_ops.write('tag', 1.0) 326 summary_ops.write('tag', 1.0) 327 summary_ops.set_step(3) 328 summary_ops.write('tag', 1.0) 329 events = events_from_logdir(logdir) 330 self.assertListEqual([1, 1, 2, 1, 3], [e.step for e in events[1:]]) 331 finally: 332 # Reset to default state for other tests. 333 summary_ops.set_step(None) 334 335 def testWrite_usingDefaultStepVariable_fromAsDefault(self): 336 logdir = self.get_temp_dir() 337 try: 338 with context.eager_mode(): 339 writer = summary_ops.create_file_writer(logdir) 340 mystep = variables.Variable(1, dtype=dtypes.int64) 341 with writer.as_default(step=mystep): 342 summary_ops.write('tag', 1.0) 343 with writer.as_default(): 344 mystep.assign(2) 345 summary_ops.write('tag', 1.0) 346 with writer.as_default(step=3): 347 summary_ops.write('tag', 1.0) 348 summary_ops.write('tag', 1.0) 349 mystep.assign(4) 350 summary_ops.write('tag', 1.0) 351 events = events_from_logdir(logdir) 352 self.assertListEqual([1, 2, 3, 2, 4], [e.step for e in events[1:]]) 353 finally: 354 # Reset to default state for other tests. 355 summary_ops.set_step(None) 356 357 def testWrite_usingDefaultStep_fromSetAsDefault(self): 358 logdir = self.get_temp_dir() 359 try: 360 with context.eager_mode(): 361 writer = summary_ops.create_file_writer(logdir) 362 mystep = variables.Variable(1, dtype=dtypes.int64) 363 writer.set_as_default(step=mystep) 364 summary_ops.write('tag', 1.0) 365 mystep.assign(2) 366 summary_ops.write('tag', 1.0) 367 writer.set_as_default(step=3) 368 summary_ops.write('tag', 1.0) 369 writer.flush() 370 events = events_from_logdir(logdir) 371 self.assertListEqual([1, 2, 3], [e.step for e in events[1:]]) 372 finally: 373 # Reset to default state for other tests. 374 summary_ops.set_step(None) 375 376 def testWrite_usingDefaultStepVariable_fromSetAsDefault(self): 377 logdir = self.get_temp_dir() 378 try: 379 with context.eager_mode(): 380 writer = summary_ops.create_file_writer(logdir) 381 writer.set_as_default(step=1) 382 summary_ops.write('tag', 1.0) 383 writer.set_as_default(step=2) 384 summary_ops.write('tag', 1.0) 385 writer.set_as_default() 386 summary_ops.write('tag', 1.0) 387 writer.flush() 388 events = events_from_logdir(logdir) 389 self.assertListEqual([1, 2, 2], [e.step for e in events[1:]]) 390 finally: 391 # Reset to default state for other tests. 392 summary_ops.set_step(None) 393 394 def testWrite_recordIf_constant(self): 395 logdir = self.get_temp_dir() 396 with context.eager_mode(): 397 with summary_ops.create_file_writer_v2(logdir).as_default(): 398 self.assertTrue(summary_ops.write('default', 1, step=0)) 399 with summary_ops.record_if(True): 400 self.assertTrue(summary_ops.write('set_on', 1, step=0)) 401 with summary_ops.record_if(False): 402 self.assertFalse(summary_ops.write('set_off', 1, step=0)) 403 events = events_from_logdir(logdir) 404 self.assertEqual(3, len(events)) 405 self.assertEqual('default', events[1].summary.value[0].tag) 406 self.assertEqual('set_on', events[2].summary.value[0].tag) 407 408 def testWrite_recordIf_constant_fromFunction(self): 409 logdir = self.get_temp_dir() 410 with context.eager_mode(): 411 writer = summary_ops.create_file_writer_v2(logdir) 412 @def_function.function 413 def f(): 414 with writer.as_default(): 415 # Use assertAllEqual instead of assertTrue since it works in a defun. 416 self.assertAllEqual(summary_ops.write('default', 1, step=0), True) 417 with summary_ops.record_if(True): 418 self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True) 419 with summary_ops.record_if(False): 420 self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False) 421 f() 422 events = events_from_logdir(logdir) 423 self.assertEqual(3, len(events)) 424 self.assertEqual('default', events[1].summary.value[0].tag) 425 self.assertEqual('set_on', events[2].summary.value[0].tag) 426 427 def testWrite_recordIf_callable(self): 428 logdir = self.get_temp_dir() 429 with context.eager_mode(): 430 step = variables.Variable(-1, dtype=dtypes.int64) 431 def record_fn(): 432 step.assign_add(1) 433 return int(step % 2) == 0 434 with summary_ops.create_file_writer_v2(logdir).as_default(): 435 with summary_ops.record_if(record_fn): 436 self.assertTrue(summary_ops.write('tag', 1, step=step)) 437 self.assertFalse(summary_ops.write('tag', 1, step=step)) 438 self.assertTrue(summary_ops.write('tag', 1, step=step)) 439 self.assertFalse(summary_ops.write('tag', 1, step=step)) 440 self.assertTrue(summary_ops.write('tag', 1, step=step)) 441 events = events_from_logdir(logdir) 442 self.assertEqual(4, len(events)) 443 self.assertEqual(0, events[1].step) 444 self.assertEqual(2, events[2].step) 445 self.assertEqual(4, events[3].step) 446 447 def testWrite_recordIf_callable_fromFunction(self): 448 logdir = self.get_temp_dir() 449 with context.eager_mode(): 450 writer = summary_ops.create_file_writer_v2(logdir) 451 step = variables.Variable(-1, dtype=dtypes.int64) 452 @def_function.function 453 def record_fn(): 454 step.assign_add(1) 455 return math_ops.equal(step % 2, 0) 456 @def_function.function 457 def f(): 458 with writer.as_default(): 459 with summary_ops.record_if(record_fn): 460 return [ 461 summary_ops.write('tag', 1, step=step), 462 summary_ops.write('tag', 1, step=step), 463 summary_ops.write('tag', 1, step=step)] 464 self.assertAllEqual(f(), [True, False, True]) 465 self.assertAllEqual(f(), [False, True, False]) 466 events = events_from_logdir(logdir) 467 self.assertEqual(4, len(events)) 468 self.assertEqual(0, events[1].step) 469 self.assertEqual(2, events[2].step) 470 self.assertEqual(4, events[3].step) 471 472 def testWrite_recordIf_tensorInput_fromFunction(self): 473 logdir = self.get_temp_dir() 474 with context.eager_mode(): 475 writer = summary_ops.create_file_writer_v2(logdir) 476 @def_function.function(input_signature=[ 477 tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)]) 478 def f(step): 479 with writer.as_default(): 480 with summary_ops.record_if(math_ops.equal(step % 2, 0)): 481 return summary_ops.write('tag', 1, step=step) 482 self.assertTrue(f(0)) 483 self.assertFalse(f(1)) 484 self.assertTrue(f(2)) 485 self.assertFalse(f(3)) 486 self.assertTrue(f(4)) 487 events = events_from_logdir(logdir) 488 self.assertEqual(4, len(events)) 489 self.assertEqual(0, events[1].step) 490 self.assertEqual(2, events[2].step) 491 self.assertEqual(4, events[3].step) 492 493 def testWriteRawPb(self): 494 logdir = self.get_temp_dir() 495 pb = summary_pb2.Summary() 496 pb.value.add().simple_value = 42.0 497 with context.eager_mode(): 498 with summary_ops.create_file_writer_v2(logdir).as_default(): 499 output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12) 500 self.assertTrue(output.numpy()) 501 events = events_from_logdir(logdir) 502 self.assertEqual(2, len(events)) 503 self.assertEqual(12, events[1].step) 504 self.assertProtoEquals(pb, events[1].summary) 505 506 def testWriteRawPb_fromFunction(self): 507 logdir = self.get_temp_dir() 508 pb = summary_pb2.Summary() 509 pb.value.add().simple_value = 42.0 510 with context.eager_mode(): 511 writer = summary_ops.create_file_writer_v2(logdir) 512 @def_function.function 513 def f(): 514 with writer.as_default(): 515 return summary_ops.write_raw_pb(pb.SerializeToString(), step=12) 516 output = f() 517 self.assertTrue(output.numpy()) 518 events = events_from_logdir(logdir) 519 self.assertEqual(2, len(events)) 520 self.assertEqual(12, events[1].step) 521 self.assertProtoEquals(pb, events[1].summary) 522 523 def testWriteRawPb_multipleValues(self): 524 logdir = self.get_temp_dir() 525 pb1 = summary_pb2.Summary() 526 pb1.value.add().simple_value = 1.0 527 pb1.value.add().simple_value = 2.0 528 pb2 = summary_pb2.Summary() 529 pb2.value.add().simple_value = 3.0 530 pb3 = summary_pb2.Summary() 531 pb3.value.add().simple_value = 4.0 532 pb3.value.add().simple_value = 5.0 533 pb3.value.add().simple_value = 6.0 534 pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)] 535 with context.eager_mode(): 536 with summary_ops.create_file_writer_v2(logdir).as_default(): 537 output = summary_ops.write_raw_pb(pbs, step=12) 538 self.assertTrue(output.numpy()) 539 events = events_from_logdir(logdir) 540 self.assertEqual(2, len(events)) 541 self.assertEqual(12, events[1].step) 542 expected_pb = summary_pb2.Summary() 543 for i in range(6): 544 expected_pb.value.add().simple_value = i + 1.0 545 self.assertProtoEquals(expected_pb, events[1].summary) 546 547 def testWriteRawPb_invalidValue(self): 548 logdir = self.get_temp_dir() 549 with context.eager_mode(): 550 with summary_ops.create_file_writer_v2(logdir).as_default(): 551 with self.assertRaisesRegex( 552 errors.DataLossError, 553 'Bad tf.compat.v1.Summary binary proto tensor string'): 554 summary_ops.write_raw_pb('notaproto', step=12) 555 556 @test_util.also_run_as_tf_function 557 def testGetSetStep(self): 558 try: 559 self.assertIsNone(summary_ops.get_step()) 560 summary_ops.set_step(1) 561 # Use assertAllEqual instead of assertEqual since it works in a defun. 562 self.assertAllEqual(1, summary_ops.get_step()) 563 summary_ops.set_step(constant_op.constant(2)) 564 self.assertAllEqual(2, summary_ops.get_step()) 565 finally: 566 # Reset to default state for other tests. 567 summary_ops.set_step(None) 568 569 def testGetSetStep_variable(self): 570 with context.eager_mode(): 571 try: 572 mystep = variables.Variable(0) 573 summary_ops.set_step(mystep) 574 self.assertAllEqual(0, summary_ops.get_step().read_value()) 575 mystep.assign_add(1) 576 self.assertAllEqual(1, summary_ops.get_step().read_value()) 577 # Check that set_step() properly maintains reference to variable. 578 del mystep 579 self.assertAllEqual(1, summary_ops.get_step().read_value()) 580 summary_ops.get_step().assign_add(1) 581 self.assertAllEqual(2, summary_ops.get_step().read_value()) 582 finally: 583 # Reset to default state for other tests. 584 summary_ops.set_step(None) 585 586 def testGetSetStep_variable_fromFunction(self): 587 with context.eager_mode(): 588 try: 589 @def_function.function 590 def set_step(step): 591 summary_ops.set_step(step) 592 return summary_ops.get_step() 593 @def_function.function 594 def get_and_increment(): 595 summary_ops.get_step().assign_add(1) 596 return summary_ops.get_step() 597 mystep = variables.Variable(0) 598 self.assertAllEqual(0, set_step(mystep)) 599 self.assertAllEqual(0, summary_ops.get_step().read_value()) 600 self.assertAllEqual(1, get_and_increment()) 601 self.assertAllEqual(2, get_and_increment()) 602 # Check that set_step() properly maintains reference to variable. 603 del mystep 604 self.assertAllEqual(3, get_and_increment()) 605 finally: 606 # Reset to default state for other tests. 607 summary_ops.set_step(None) 608 609 @test_util.also_run_as_tf_function 610 def testSummaryScope(self): 611 with summary_ops.summary_scope('foo') as (tag, scope): 612 self.assertEqual('foo', tag) 613 self.assertEqual('foo/', scope) 614 with summary_ops.summary_scope('bar') as (tag, scope): 615 self.assertEqual('foo/bar', tag) 616 self.assertEqual('foo/bar/', scope) 617 with summary_ops.summary_scope('with/slash') as (tag, scope): 618 self.assertEqual('foo/with/slash', tag) 619 self.assertEqual('foo/with/slash/', scope) 620 with ops.name_scope(None, skip_on_eager=False): 621 with summary_ops.summary_scope('unnested') as (tag, scope): 622 self.assertEqual('unnested', tag) 623 self.assertEqual('unnested/', scope) 624 625 @test_util.also_run_as_tf_function 626 def testSummaryScope_defaultName(self): 627 with summary_ops.summary_scope(None) as (tag, scope): 628 self.assertEqual('summary', tag) 629 self.assertEqual('summary/', scope) 630 with summary_ops.summary_scope(None, 'backup') as (tag, scope): 631 self.assertEqual('backup', tag) 632 self.assertEqual('backup/', scope) 633 634 @test_util.also_run_as_tf_function 635 def testSummaryScope_handlesCharactersIllegalForScope(self): 636 with summary_ops.summary_scope('f?o?o') as (tag, scope): 637 self.assertEqual('f?o?o', tag) 638 self.assertEqual('foo/', scope) 639 # If all characters aren't legal for a scope name, use default name. 640 with summary_ops.summary_scope('???', 'backup') as (tag, scope): 641 self.assertEqual('???', tag) 642 self.assertEqual('backup/', scope) 643 644 @test_util.also_run_as_tf_function 645 def testSummaryScope_nameNotUniquifiedForTag(self): 646 constant_op.constant(0, name='foo') 647 with summary_ops.summary_scope('foo') as (tag, _): 648 self.assertEqual('foo', tag) 649 with summary_ops.summary_scope('foo') as (tag, _): 650 self.assertEqual('foo', tag) 651 with ops.name_scope('with', skip_on_eager=False): 652 constant_op.constant(0, name='slash') 653 with summary_ops.summary_scope('with/slash') as (tag, _): 654 self.assertEqual('with/slash', tag) 655 656 def testAllV2SummaryOps(self): 657 logdir = self.get_temp_dir() 658 def define_ops(): 659 result = [] 660 # TF 2.0 summary ops 661 result.append(summary_ops.write('write', 1, step=0)) 662 result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb')) 663 # TF 1.x tf.contrib.summary ops 664 result.append(summary_ops.generic('tensor', 1, step=1)) 665 result.append(summary_ops.scalar('scalar', 2.0, step=1)) 666 result.append(summary_ops.histogram('histogram', [1.0], step=1)) 667 result.append(summary_ops.image('image', [[[[1.0]]]], step=1)) 668 result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)) 669 return result 670 with context.graph_mode(): 671 ops_without_writer = define_ops() 672 with summary_ops.create_file_writer_v2(logdir).as_default(): 673 with summary_ops.record_if(True): 674 ops_recording_on = define_ops() 675 with summary_ops.record_if(False): 676 ops_recording_off = define_ops() 677 # We should be collecting all ops defined with a default writer present, 678 # regardless of whether recording was set on or off, but not those defined 679 # without a writer at all. 680 del ops_without_writer 681 expected_ops = ops_recording_on + ops_recording_off 682 self.assertCountEqual(expected_ops, summary_ops.all_v2_summary_ops()) 683 684 685class SummaryWriterTest(test_util.TensorFlowTestCase): 686 687 def testCreate_withInitAndClose(self): 688 logdir = self.get_temp_dir() 689 with context.eager_mode(): 690 writer = summary_ops.create_file_writer_v2( 691 logdir, max_queue=1000, flush_millis=1000000) 692 get_total = lambda: len(events_from_logdir(logdir)) 693 self.assertEqual(1, get_total()) # file_version Event 694 # Calling init() again while writer is open has no effect 695 writer.init() 696 self.assertEqual(1, get_total()) 697 with writer.as_default(): 698 summary_ops.write('tag', 1, step=0) 699 self.assertEqual(1, get_total()) 700 # Calling .close() should do an implicit flush 701 writer.close() 702 self.assertEqual(2, get_total()) 703 704 def testCreate_fromFunction(self): 705 logdir = self.get_temp_dir() 706 @def_function.function 707 def f(): 708 # Returned SummaryWriter must be stored in a non-local variable so it 709 # lives throughout the function execution. 710 if not hasattr(f, 'writer'): 711 f.writer = summary_ops.create_file_writer_v2(logdir) 712 with context.eager_mode(): 713 f() 714 event_files = gfile.Glob(os.path.join(logdir, '*')) 715 self.assertEqual(1, len(event_files)) 716 717 def testCreate_graphTensorArgument_raisesError(self): 718 logdir = self.get_temp_dir() 719 with context.graph_mode(): 720 logdir_tensor = constant_op.constant(logdir) 721 with context.eager_mode(): 722 with self.assertRaisesRegex( 723 ValueError, 'Invalid graph Tensor argument.*logdir'): 724 summary_ops.create_file_writer_v2(logdir_tensor) 725 self.assertEmpty(gfile.Glob(os.path.join(logdir, '*'))) 726 727 def testCreate_fromFunction_graphTensorArgument_raisesError(self): 728 logdir = self.get_temp_dir() 729 @def_function.function 730 def f(): 731 summary_ops.create_file_writer_v2(constant_op.constant(logdir)) 732 with context.eager_mode(): 733 with self.assertRaisesRegex( 734 ValueError, 'Invalid graph Tensor argument.*logdir'): 735 f() 736 self.assertEmpty(gfile.Glob(os.path.join(logdir, '*'))) 737 738 def testCreate_fromFunction_unpersistedResource_raisesError(self): 739 logdir = self.get_temp_dir() 740 @def_function.function 741 def f(): 742 with summary_ops.create_file_writer_v2(logdir).as_default(): 743 pass # Calling .as_default() is enough to indicate use. 744 with context.eager_mode(): 745 # TODO(nickfelt): change this to a better error 746 with self.assertRaisesRegex( 747 errors.NotFoundError, 'Resource.*does not exist'): 748 f() 749 # Even though we didn't use it, an event file will have been created. 750 self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*')))) 751 752 def testCreate_immediateSetAsDefault_retainsReference(self): 753 logdir = self.get_temp_dir() 754 try: 755 with context.eager_mode(): 756 summary_ops.create_file_writer_v2(logdir).set_as_default() 757 summary_ops.flush() 758 finally: 759 # Ensure we clean up no matter how the test executes. 760 summary_ops._summary_state.writer = None # pylint: disable=protected-access 761 762 def testCreate_immediateAsDefault_retainsReference(self): 763 logdir = self.get_temp_dir() 764 with context.eager_mode(): 765 with summary_ops.create_file_writer_v2(logdir).as_default(): 766 summary_ops.flush() 767 768 def testNoSharing(self): 769 # Two writers with the same logdir should not share state. 770 logdir = self.get_temp_dir() 771 with context.eager_mode(): 772 writer1 = summary_ops.create_file_writer_v2(logdir) 773 with writer1.as_default(): 774 summary_ops.write('tag', 1, step=1) 775 event_files = gfile.Glob(os.path.join(logdir, '*')) 776 self.assertEqual(1, len(event_files)) 777 file1 = event_files[0] 778 779 writer2 = summary_ops.create_file_writer_v2(logdir) 780 with writer2.as_default(): 781 summary_ops.write('tag', 1, step=2) 782 event_files = gfile.Glob(os.path.join(logdir, '*')) 783 self.assertEqual(2, len(event_files)) 784 event_files.remove(file1) 785 file2 = event_files[0] 786 787 # Extra writes to ensure interleaved usage works. 788 with writer1.as_default(): 789 summary_ops.write('tag', 1, step=1) 790 with writer2.as_default(): 791 summary_ops.write('tag', 1, step=2) 792 793 events = iter(events_from_file(file1)) 794 self.assertEqual('brain.Event:2', next(events).file_version) 795 self.assertEqual(1, next(events).step) 796 self.assertEqual(1, next(events).step) 797 self.assertRaises(StopIteration, lambda: next(events)) 798 events = iter(events_from_file(file2)) 799 self.assertEqual('brain.Event:2', next(events).file_version) 800 self.assertEqual(2, next(events).step) 801 self.assertEqual(2, next(events).step) 802 self.assertRaises(StopIteration, lambda: next(events)) 803 804 def testNoSharing_fromFunction(self): 805 logdir = self.get_temp_dir() 806 @def_function.function 807 def f1(): 808 if not hasattr(f1, 'writer'): 809 f1.writer = summary_ops.create_file_writer_v2(logdir) 810 with f1.writer.as_default(): 811 summary_ops.write('tag', 1, step=1) 812 @def_function.function 813 def f2(): 814 if not hasattr(f2, 'writer'): 815 f2.writer = summary_ops.create_file_writer_v2(logdir) 816 with f2.writer.as_default(): 817 summary_ops.write('tag', 1, step=2) 818 with context.eager_mode(): 819 f1() 820 event_files = gfile.Glob(os.path.join(logdir, '*')) 821 self.assertEqual(1, len(event_files)) 822 file1 = event_files[0] 823 824 f2() 825 event_files = gfile.Glob(os.path.join(logdir, '*')) 826 self.assertEqual(2, len(event_files)) 827 event_files.remove(file1) 828 file2 = event_files[0] 829 830 # Extra writes to ensure interleaved usage works. 831 f1() 832 f2() 833 834 events = iter(events_from_file(file1)) 835 self.assertEqual('brain.Event:2', next(events).file_version) 836 self.assertEqual(1, next(events).step) 837 self.assertEqual(1, next(events).step) 838 self.assertRaises(StopIteration, lambda: next(events)) 839 events = iter(events_from_file(file2)) 840 self.assertEqual('brain.Event:2', next(events).file_version) 841 self.assertEqual(2, next(events).step) 842 self.assertEqual(2, next(events).step) 843 self.assertRaises(StopIteration, lambda: next(events)) 844 845 def testMaxQueue(self): 846 logdir = self.get_temp_dir() 847 with context.eager_mode(): 848 with summary_ops.create_file_writer_v2( 849 logdir, max_queue=1, flush_millis=999999).as_default(): 850 get_total = lambda: len(events_from_logdir(logdir)) 851 # Note: First tf.compat.v1.Event is always file_version. 852 self.assertEqual(1, get_total()) 853 summary_ops.write('tag', 1, step=0) 854 self.assertEqual(1, get_total()) 855 # Should flush after second summary since max_queue = 1 856 summary_ops.write('tag', 1, step=0) 857 self.assertEqual(3, get_total()) 858 859 def testWriterFlush(self): 860 logdir = self.get_temp_dir() 861 get_total = lambda: len(events_from_logdir(logdir)) 862 with context.eager_mode(): 863 writer = summary_ops.create_file_writer_v2( 864 logdir, max_queue=1000, flush_millis=1000000) 865 self.assertEqual(1, get_total()) # file_version Event 866 with writer.as_default(): 867 summary_ops.write('tag', 1, step=0) 868 self.assertEqual(1, get_total()) 869 writer.flush() 870 self.assertEqual(2, get_total()) 871 summary_ops.write('tag', 1, step=0) 872 self.assertEqual(2, get_total()) 873 # Exiting the "as_default()" should do an implicit flush 874 self.assertEqual(3, get_total()) 875 876 def testFlushFunction(self): 877 logdir = self.get_temp_dir() 878 with context.eager_mode(): 879 writer = summary_ops.create_file_writer_v2( 880 logdir, max_queue=999999, flush_millis=999999) 881 with writer.as_default(): 882 get_total = lambda: len(events_from_logdir(logdir)) 883 # Note: First tf.compat.v1.Event is always file_version. 884 self.assertEqual(1, get_total()) 885 summary_ops.write('tag', 1, step=0) 886 summary_ops.write('tag', 1, step=0) 887 self.assertEqual(1, get_total()) 888 summary_ops.flush() 889 self.assertEqual(3, get_total()) 890 # Test "writer" parameter 891 summary_ops.write('tag', 1, step=0) 892 self.assertEqual(3, get_total()) 893 summary_ops.flush(writer=writer) 894 self.assertEqual(4, get_total()) 895 summary_ops.write('tag', 1, step=0) 896 self.assertEqual(4, get_total()) 897 summary_ops.flush(writer=writer._resource) # pylint:disable=protected-access 898 self.assertEqual(5, get_total()) 899 900 @test_util.assert_no_new_tensors 901 def testNoMemoryLeak_graphMode(self): 902 logdir = self.get_temp_dir() 903 with context.graph_mode(), ops.Graph().as_default(): 904 summary_ops.create_file_writer_v2(logdir) 905 906 @test_util.assert_no_new_pyobjects_executing_eagerly 907 def testNoMemoryLeak_eagerMode(self): 908 logdir = self.get_temp_dir() 909 with summary_ops.create_file_writer_v2(logdir).as_default(): 910 summary_ops.write('tag', 1, step=0) 911 912 def testClose_preventsLaterUse(self): 913 logdir = self.get_temp_dir() 914 with context.eager_mode(): 915 writer = summary_ops.create_file_writer_v2(logdir) 916 writer.close() 917 writer.close() # redundant close() is a no-op 918 writer.flush() # redundant flush() is a no-op 919 with self.assertRaisesRegex(RuntimeError, 'already closed'): 920 writer.init() 921 with self.assertRaisesRegex(RuntimeError, 'already closed'): 922 with writer.as_default(): 923 self.fail('should not get here') 924 with self.assertRaisesRegex(RuntimeError, 'already closed'): 925 writer.set_as_default() 926 927 def testClose_closesOpenFile(self): 928 try: 929 import psutil # pylint: disable=g-import-not-at-top 930 except ImportError: 931 raise unittest.SkipTest('test requires psutil') 932 proc = psutil.Process() 933 get_open_filenames = lambda: set(info[0] for info in proc.open_files()) 934 logdir = self.get_temp_dir() 935 with context.eager_mode(): 936 writer = summary_ops.create_file_writer_v2(logdir) 937 files = gfile.Glob(os.path.join(logdir, '*')) 938 self.assertEqual(1, len(files)) 939 eventfile = files[0] 940 self.assertIn(eventfile, get_open_filenames()) 941 writer.close() 942 self.assertNotIn(eventfile, get_open_filenames()) 943 944 def testDereference_closesOpenFile(self): 945 try: 946 import psutil # pylint: disable=g-import-not-at-top 947 except ImportError: 948 raise unittest.SkipTest('test requires psutil') 949 proc = psutil.Process() 950 get_open_filenames = lambda: set(info[0] for info in proc.open_files()) 951 logdir = self.get_temp_dir() 952 with context.eager_mode(): 953 writer = summary_ops.create_file_writer_v2(logdir) 954 files = gfile.Glob(os.path.join(logdir, '*')) 955 self.assertEqual(1, len(files)) 956 eventfile = files[0] 957 self.assertIn(eventfile, get_open_filenames()) 958 del writer 959 self.assertNotIn(eventfile, get_open_filenames()) 960 961 962class SummaryOpsTest(test_util.TensorFlowTestCase): 963 964 def tearDown(self): 965 summary_ops.trace_off() 966 967 def exec_summary_op(self, summary_op_fn): 968 assert context.executing_eagerly() 969 logdir = self.get_temp_dir() 970 writer = summary_ops.create_file_writer(logdir) 971 with writer.as_default(): 972 summary_op_fn() 973 writer.close() 974 events = events_from_logdir(logdir) 975 return events[1] 976 977 def run_metadata(self, *args, **kwargs): 978 assert context.executing_eagerly() 979 logdir = self.get_temp_dir() 980 writer = summary_ops.create_file_writer(logdir) 981 with writer.as_default(): 982 summary_ops.run_metadata(*args, **kwargs) 983 writer.close() 984 events = events_from_logdir(logdir) 985 return events[1] 986 987 def run_metadata_graphs(self, *args, **kwargs): 988 assert context.executing_eagerly() 989 logdir = self.get_temp_dir() 990 writer = summary_ops.create_file_writer(logdir) 991 with writer.as_default(): 992 summary_ops.run_metadata_graphs(*args, **kwargs) 993 writer.close() 994 events = events_from_logdir(logdir) 995 return events[1] 996 997 def create_run_metadata(self): 998 step_stats = step_stats_pb2.StepStats(dev_stats=[ 999 step_stats_pb2.DeviceStepStats( 1000 device='cpu:0', 1001 node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')]) 1002 ]) 1003 return config_pb2.RunMetadata( 1004 function_graphs=[ 1005 config_pb2.RunMetadata.FunctionGraphs( 1006 pre_optimization_graph=graph_pb2.GraphDef( 1007 node=[node_def_pb2.NodeDef(name='foo')])) 1008 ], 1009 step_stats=step_stats) 1010 1011 def run_trace(self, f, step=1): 1012 assert context.executing_eagerly() 1013 logdir = self.get_temp_dir() 1014 writer = summary_ops.create_file_writer(logdir) 1015 summary_ops.trace_on(graph=True, profiler=False) 1016 with writer.as_default(): 1017 f() 1018 summary_ops.trace_export(name='foo', step=step) 1019 writer.close() 1020 events = events_from_logdir(logdir) 1021 return events[1] 1022 1023 @test_util.run_v2_only 1024 def testRunMetadata_usesNameAsTag(self): 1025 meta = config_pb2.RunMetadata() 1026 1027 with ops.name_scope('foo', skip_on_eager=False): 1028 event = self.run_metadata(name='my_name', data=meta, step=1) 1029 first_val = event.summary.value[0] 1030 1031 self.assertEqual('foo/my_name', first_val.tag) 1032 1033 @test_util.run_v2_only 1034 def testRunMetadata_summaryMetadata(self): 1035 expected_summary_metadata = """ 1036 plugin_data { 1037 plugin_name: "graph_run_metadata" 1038 content: "1" 1039 } 1040 """ 1041 meta = config_pb2.RunMetadata() 1042 event = self.run_metadata(name='my_name', data=meta, step=1) 1043 actual_summary_metadata = event.summary.value[0].metadata 1044 self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata) 1045 1046 @test_util.run_v2_only 1047 def testRunMetadata_wholeRunMetadata(self): 1048 expected_run_metadata = """ 1049 step_stats { 1050 dev_stats { 1051 device: "cpu:0" 1052 node_stats { 1053 node_name: "hello" 1054 } 1055 } 1056 } 1057 function_graphs { 1058 pre_optimization_graph { 1059 node { 1060 name: "foo" 1061 } 1062 } 1063 } 1064 """ 1065 meta = self.create_run_metadata() 1066 event = self.run_metadata(name='my_name', data=meta, step=1) 1067 first_val = event.summary.value[0] 1068 1069 actual_run_metadata = config_pb2.RunMetadata.FromString( 1070 first_val.tensor.string_val[0]) 1071 self.assertProtoEquals(expected_run_metadata, actual_run_metadata) 1072 1073 @test_util.run_v2_only 1074 def testRunMetadata_usesDefaultStep(self): 1075 meta = config_pb2.RunMetadata() 1076 try: 1077 summary_ops.set_step(42) 1078 event = self.run_metadata(name='my_name', data=meta) 1079 self.assertEqual(42, event.step) 1080 finally: 1081 # Reset to default state for other tests. 1082 summary_ops.set_step(None) 1083 1084 @test_util.run_v2_only 1085 def testRunMetadataGraph_usesNameAsTag(self): 1086 meta = config_pb2.RunMetadata() 1087 1088 with ops.name_scope('foo', skip_on_eager=False): 1089 event = self.run_metadata_graphs(name='my_name', data=meta, step=1) 1090 first_val = event.summary.value[0] 1091 1092 self.assertEqual('foo/my_name', first_val.tag) 1093 1094 @test_util.run_v2_only 1095 def testRunMetadataGraph_summaryMetadata(self): 1096 expected_summary_metadata = """ 1097 plugin_data { 1098 plugin_name: "graph_run_metadata_graph" 1099 content: "1" 1100 } 1101 """ 1102 meta = config_pb2.RunMetadata() 1103 event = self.run_metadata_graphs(name='my_name', data=meta, step=1) 1104 actual_summary_metadata = event.summary.value[0].metadata 1105 self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata) 1106 1107 @test_util.run_v2_only 1108 def testRunMetadataGraph_runMetadataFragment(self): 1109 expected_run_metadata = """ 1110 function_graphs { 1111 pre_optimization_graph { 1112 node { 1113 name: "foo" 1114 } 1115 } 1116 } 1117 """ 1118 meta = self.create_run_metadata() 1119 1120 event = self.run_metadata_graphs(name='my_name', data=meta, step=1) 1121 first_val = event.summary.value[0] 1122 1123 actual_run_metadata = config_pb2.RunMetadata.FromString( 1124 first_val.tensor.string_val[0]) 1125 self.assertProtoEquals(expected_run_metadata, actual_run_metadata) 1126 1127 @test_util.run_v2_only 1128 def testRunMetadataGraph_usesDefaultStep(self): 1129 meta = config_pb2.RunMetadata() 1130 try: 1131 summary_ops.set_step(42) 1132 event = self.run_metadata_graphs(name='my_name', data=meta) 1133 self.assertEqual(42, event.step) 1134 finally: 1135 # Reset to default state for other tests. 1136 summary_ops.set_step(None) 1137 1138 @test_util.run_v2_only 1139 def testTrace(self): 1140 1141 @def_function.function 1142 def f(): 1143 x = constant_op.constant(2) 1144 y = constant_op.constant(3) 1145 return x**y 1146 1147 event = self.run_trace(f) 1148 1149 first_val = event.summary.value[0] 1150 actual_run_metadata = config_pb2.RunMetadata.FromString( 1151 first_val.tensor.string_val[0]) 1152 1153 # Content of function_graphs is large and, for instance, device can change. 1154 self.assertTrue(hasattr(actual_run_metadata, 'function_graphs')) 1155 1156 @test_util.run_v2_only 1157 def testTrace_cannotEnableTraceInFunction(self): 1158 1159 @def_function.function 1160 def f(): 1161 summary_ops.trace_on(graph=True, profiler=False) 1162 x = constant_op.constant(2) 1163 y = constant_op.constant(3) 1164 return x**y 1165 1166 with test.mock.patch.object(logging, 'warn') as mock_log: 1167 f() 1168 self.assertRegex( 1169 str(mock_log.call_args), 'Cannot enable trace inside a tf.function.') 1170 1171 @test_util.run_v2_only 1172 def testTrace_cannotEnableTraceInGraphMode(self): 1173 with test.mock.patch.object(logging, 'warn') as mock_log: 1174 with context.graph_mode(): 1175 summary_ops.trace_on(graph=True, profiler=False) 1176 self.assertRegex( 1177 str(mock_log.call_args), 'Must enable trace in eager mode.') 1178 1179 @test_util.run_v2_only 1180 def testTrace_cannotExportTraceWithoutTrace(self): 1181 with six.assertRaisesRegex(self, ValueError, 1182 'Must enable trace before export.'): 1183 summary_ops.trace_export(name='foo', step=1) 1184 1185 @test_util.run_v2_only 1186 def testTrace_cannotExportTraceInFunction(self): 1187 summary_ops.trace_on(graph=True, profiler=False) 1188 1189 @def_function.function 1190 def f(): 1191 x = constant_op.constant(2) 1192 y = constant_op.constant(3) 1193 summary_ops.trace_export(name='foo', step=1) 1194 return x**y 1195 1196 with test.mock.patch.object(logging, 'warn') as mock_log: 1197 f() 1198 self.assertRegex( 1199 str(mock_log.call_args), 'Cannot export trace inside a tf.function.') 1200 1201 @test_util.run_v2_only 1202 def testTrace_cannotExportTraceInGraphMode(self): 1203 with test.mock.patch.object(logging, 'warn') as mock_log: 1204 with context.graph_mode(): 1205 summary_ops.trace_export(name='foo', step=1) 1206 self.assertRegex( 1207 str(mock_log.call_args), 1208 'Can only export trace while executing eagerly.') 1209 1210 @test_util.run_v2_only 1211 def testTrace_usesDefaultStep(self): 1212 1213 @def_function.function 1214 def f(): 1215 x = constant_op.constant(2) 1216 y = constant_op.constant(3) 1217 return x**y 1218 1219 try: 1220 summary_ops.set_step(42) 1221 event = self.run_trace(f, step=None) 1222 self.assertEqual(42, event.step) 1223 finally: 1224 # Reset to default state for other tests. 1225 summary_ops.set_step(None) 1226 1227 @test_util.run_v2_only 1228 def testTrace_withProfiler(self): 1229 1230 @def_function.function 1231 def f(): 1232 x = constant_op.constant(2) 1233 y = constant_op.constant(3) 1234 return x**y 1235 1236 assert context.executing_eagerly() 1237 logdir = self.get_temp_dir() 1238 writer = summary_ops.create_file_writer(logdir) 1239 summary_ops.trace_on(graph=True, profiler=True) 1240 profiler_outdir = self.get_temp_dir() 1241 with writer.as_default(): 1242 f() 1243 summary_ops.trace_export( 1244 name='foo', step=1, profiler_outdir=profiler_outdir) 1245 writer.close() 1246 1247 @test_util.run_v2_only 1248 def testGraph_graph(self): 1249 1250 @def_function.function 1251 def f(): 1252 x = constant_op.constant(2) 1253 y = constant_op.constant(3) 1254 return x**y 1255 1256 def summary_op_fn(): 1257 summary_ops.graph(f.get_concrete_function().graph) 1258 1259 event = self.exec_summary_op(summary_op_fn) 1260 self.assertIsNotNone(event.graph_def) 1261 1262 @test_util.run_v2_only 1263 def testGraph_graphDef(self): 1264 1265 @def_function.function 1266 def f(): 1267 x = constant_op.constant(2) 1268 y = constant_op.constant(3) 1269 return x**y 1270 1271 def summary_op_fn(): 1272 summary_ops.graph(f.get_concrete_function().graph.as_graph_def()) 1273 1274 event = self.exec_summary_op(summary_op_fn) 1275 self.assertIsNotNone(event.graph_def) 1276 1277 @test_util.run_v2_only 1278 def testGraph_invalidData(self): 1279 def summary_op_fn(): 1280 summary_ops.graph('hello') 1281 1282 with self.assertRaisesRegex( 1283 ValueError, 1284 r'\'graph_data\' is not tf.Graph or tf.compat.v1.GraphDef', 1285 ): 1286 self.exec_summary_op(summary_op_fn) 1287 1288 @test_util.run_v2_only 1289 def testGraph_fromGraphMode(self): 1290 1291 @def_function.function 1292 def f(): 1293 x = constant_op.constant(2) 1294 y = constant_op.constant(3) 1295 return x**y 1296 1297 @def_function.function 1298 def g(graph): 1299 summary_ops.graph(graph) 1300 1301 def summary_op_fn(): 1302 graph_def = f.get_concrete_function().graph.as_graph_def(add_shapes=True) 1303 func_graph = constant_op.constant(graph_def.SerializeToString()) 1304 g(func_graph) 1305 1306 with self.assertRaisesRegex( 1307 ValueError, 1308 r'graph\(\) cannot be invoked inside a graph context.', 1309 ): 1310 self.exec_summary_op(summary_op_fn) 1311 1312 1313def events_from_file(filepath): 1314 """Returns all events in a single event file. 1315 1316 Args: 1317 filepath: Path to the event file. 1318 1319 Returns: 1320 A list of all tf.Event protos in the event file. 1321 """ 1322 records = list(tf_record.tf_record_iterator(filepath)) 1323 result = [] 1324 for r in records: 1325 event = event_pb2.Event() 1326 event.ParseFromString(r) 1327 result.append(event) 1328 return result 1329 1330 1331def events_from_logdir(logdir): 1332 """Returns all events in the single eventfile in logdir. 1333 1334 Args: 1335 logdir: The directory in which the single event file is sought. 1336 1337 Returns: 1338 A list of all tf.Event protos from the single event file. 1339 1340 Raises: 1341 AssertionError: If logdir does not contain exactly one file. 1342 """ 1343 assert gfile.Exists(logdir) 1344 files = gfile.ListDirectory(logdir) 1345 assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files 1346 return events_from_file(os.path.join(logdir, files[0])) 1347 1348 1349def to_numpy(summary_value): 1350 return tensor_util.MakeNdarray(summary_value.tensor) 1351 1352 1353if __name__ == '__main__': 1354 test.main() 1355