# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V2 summary ops from summary_ops_v2."""

import os
import unittest

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load as saved_model_load
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import save as saved_model_save
from tensorflow.python.saved_model import tag_constants


class SummaryOpsCoreTest(test_util.TensorFlowTestCase):

  def testWrite(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write('tag', 42, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write('tag', 42, step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_metadata(self):
    logdir = self.get_temp_dir()
    metadata = summary_pb2.SummaryMetadata()
    metadata.plugin_data.plugin_name = 'foo'
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('obj', 0, 0, metadata=metadata)
        summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString())
        m = constant_op.constant(metadata.SerializeToString())
        summary_ops.write('string_tensor', 0, 0, metadata=m)
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(metadata, events[1].summary.value[0].metadata)
    self.assertEqual(metadata, events[2].summary.value[0].metadata)
    self.assertEqual(metadata, events[3].summary.value[0].metadata)

  def testWrite_name(self):
    @def_function.function
    def f():
      output = summary_ops.write('tag', 42, step=12, name='anonymous')
      self.assertTrue(output.name.startswith('anonymous'))
    f()

  def testWrite_ndarray(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value))

  def testWrite_tensor(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      t = constant_op.constant([[1, 2], [3, 4]])
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', t, step=12)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_tensor_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f(t):
        with writer.as_default():
          summary_ops.write('tag', t, step=12)
      t = constant_op.constant([[1, 2], [3, 4]])
      f(t)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_stringTensor(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [b'foo', b'bar'], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([b'foo', b'bar'], to_numpy(value))

  @test_util.run_gpu_only
  def testWrite_gpuDeviceContext(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with ops.device('/GPU:0'):
          value = constant_op.constant(42.0)
          step = constant_op.constant(12, dtype=dtypes.int64)
          summary_ops.write('tag', value, step=step).numpy()
    empty_metadata = summary_pb2.SummaryMetadata()
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertEqual(42, to_numpy(events[1].summary.value[0]))
    self.assertEqual(empty_metadata, events[1].summary.value[0].metadata)

  @test_util.also_run_as_tf_function
  def testWrite_noDefaultWriter(self):
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42, step=0))

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfAlsoNoDefaultWriter(self):
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42))

  def testWrite_noStep(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(ValueError, 'No step set'):
          summary_ops.write('tag', 42)

  def testWrite_noStep_okayIfNotRecordingSummaries(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('tag', 42))

  def testWrite_usingDefaultStep(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        with summary_ops.create_file_writer_v2(logdir).as_default():
          summary_ops.set_step(1)
          summary_ops.write('tag', 1.0)
          summary_ops.set_step(2)
          summary_ops.write('tag', 1.0)
          mystep = variables.Variable(10, dtype=dtypes.int64)
          summary_ops.set_step(mystep)
          summary_ops.write('tag', 1.0)
          mystep.assign_add(1)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertEqual(5, len(events))
      self.assertEqual(1, events[1].step)
      self.assertEqual(2, events[2].step)
      self.assertEqual(10, events[3].step)
      self.assertEqual(11, events[4].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromFunction(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        summary_ops.set_step(1)
        f()
        summary_ops.set_step(2)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the function was first traced.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromFunction(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        f()
        mystep.assign_add(1)
        f()
        mystep.assign(10)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromLegacyGraph(self):
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        summary_ops.set_step(1)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        summary_ops.set_step(2)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(write_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the graph was constructed.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromLegacyGraph(self):
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        first_assign_op = mystep.assign_add(1)
        second_assign_op = mystep.assign(10)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(mystep.initializer)
          sess.run(write_op)
          sess.run(first_assign_op)
          sess.run(write_op)
          sess.run(second_assign_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        with writer.as_default(step=1):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=2):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            summary_ops.set_step(3)
            summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 1, 2, 1, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        with writer.as_default(step=mystep):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            mystep.assign(2)
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=3):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            mystep.assign(4)
            summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3, 2, 4], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromSetAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        writer.set_as_default(step=mystep)
        summary_ops.write('tag', 1.0)
        mystep.assign(2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=3)
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromSetAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        writer.set_as_default(step=1)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default()
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 2], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_recordIf_constant(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        self.assertTrue(summary_ops.write('default', 1, step=0))
        with summary_ops.record_if(True):
          self.assertTrue(summary_ops.write('set_on', 1, step=0))
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('set_off', 1, step=0))
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_constant_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          # Use assertAllEqual instead of assertTrue since it works in a defun.
          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
          with summary_ops.record_if(True):
            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
          with summary_ops.record_if(False):
            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
      f()
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_callable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      step = variables.Variable(-1, dtype=dtypes.int64)
      def record_fn():
        step.assign_add(1)
        return int(step % 2) == 0
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(record_fn):
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_callable_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      step = variables.Variable(-1, dtype=dtypes.int64)
      @def_function.function
      def record_fn():
        step.assign_add(1)
        return math_ops.equal(step % 2, 0)
      @def_function.function
      def f():
        with writer.as_default():
          with summary_ops.record_if(record_fn):
            return [
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step)]
      self.assertAllEqual(f(), [True, False, True])
      self.assertAllEqual(f(), [False, True, False])
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_tensorInput_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
      def f(step):
        with writer.as_default():
          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
            return summary_ops.write('tag', 1, step=step)
      self.assertTrue(f(0))
      self.assertFalse(f(1))
      self.assertTrue(f(2))
      self.assertFalse(f(3))
      self.assertTrue(f(4))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWriteRawPb(self):
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_fromFunction(self):
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_multipleValues(self):
    logdir = self.get_temp_dir()
    pb1 = summary_pb2.Summary()
    pb1.value.add().simple_value = 1.0
    pb1.value.add().simple_value = 2.0
    pb2 = summary_pb2.Summary()
    pb2.value.add().simple_value = 3.0
    pb3 = summary_pb2.Summary()
    pb3.value.add().simple_value = 4.0
    pb3.value.add().simple_value = 5.0
    pb3.value.add().simple_value = 6.0
    pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)]
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pbs, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    expected_pb = summary_pb2.Summary()
    for i in range(6):
      expected_pb.value.add().simple_value = i + 1.0
    self.assertProtoEquals(expected_pb, events[1].summary)

  def testWriteRawPb_invalidValue(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(
            errors.DataLossError,
            'Bad tf.compat.v1.Summary binary proto tensor string'):
          summary_ops.write_raw_pb('notaproto', step=12)

  @test_util.also_run_as_tf_function
  def testGetSetStep(self):
    try:
      self.assertIsNone(summary_ops.get_step())
      summary_ops.set_step(1)
      # Use assertAllEqual instead of assertEqual since it works in a defun.
      self.assertAllEqual(1, summary_ops.get_step())
      summary_ops.set_step(constant_op.constant(2))
      self.assertAllEqual(2, summary_ops.get_step())
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testGetSetStep_variable(self):
    with context.eager_mode():
      try:
        mystep = variables.Variable(0)
        summary_ops.set_step(mystep)
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        mystep.assign_add(1)
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        summary_ops.get_step().assign_add(1)
        self.assertAllEqual(2, summary_ops.get_step().read_value())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  def testGetSetStep_variable_fromFunction(self):
    with context.eager_mode():
      try:
        @def_function.function
        def set_step(step):
          summary_ops.set_step(step)
          return summary_ops.get_step()
        @def_function.function
        def get_and_increment():
          summary_ops.get_step().assign_add(1)
          return summary_ops.get_step()
        mystep = variables.Variable(0)
        self.assertAllEqual(0, set_step(mystep))
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        self.assertAllEqual(1, get_and_increment())
        self.assertAllEqual(2, get_and_increment())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(3, get_and_increment())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  @test_util.also_run_as_tf_function
  def testSummaryScope(self):
    with summary_ops.summary_scope('foo') as (tag, scope):
      self.assertEqual('foo', tag)
      self.assertEqual('foo/', scope)
      with summary_ops.summary_scope('bar') as (tag, scope):
        self.assertEqual('foo/bar', tag)
        self.assertEqual('foo/bar/', scope)
      with summary_ops.summary_scope('with/slash') as (tag, scope):
        self.assertEqual('foo/with/slash', tag)
        self.assertEqual('foo/with/slash/', scope)
      with ops.name_scope(None, skip_on_eager=False):
        with summary_ops.summary_scope('unnested') as (tag, scope):
          self.assertEqual('unnested', tag)
          self.assertEqual('unnested/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_defaultName(self):
    with summary_ops.summary_scope(None) as (tag, scope):
      self.assertEqual('summary', tag)
      self.assertEqual('summary/', scope)
    with summary_ops.summary_scope(None, 'backup') as (tag, scope):
      self.assertEqual('backup', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_handlesCharactersIllegalForScope(self):
    with summary_ops.summary_scope('f?o?o') as (tag, scope):
      self.assertEqual('f?o?o', tag)
      self.assertEqual('foo/', scope)
    # If all characters aren't legal for a scope name, use default name.
    with summary_ops.summary_scope('???', 'backup') as (tag, scope):
      self.assertEqual('???', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_nameNotUniquifiedForTag(self):
    constant_op.constant(0, name='foo')
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with ops.name_scope('with', skip_on_eager=False):
      constant_op.constant(0, name='slash')
    with summary_ops.summary_scope('with/slash') as (tag, _):
      self.assertEqual('with/slash', tag)

  def testAllV2SummaryOps(self):
    logdir = self.get_temp_dir()
    def define_ops():
      result = []
      # TF 2.0 summary ops
      result.append(summary_ops.write('write', 1, step=0))
      result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb'))
      # TF 1.x tf.contrib.summary ops
      result.append(summary_ops.generic('tensor', 1, step=1))
      result.append(summary_ops.scalar('scalar', 2.0, step=1))
      result.append(summary_ops.histogram('histogram', [1.0], step=1))
      result.append(summary_ops.image('image', [[[[1.0]]]], step=1))
      result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1))
      return result
    with context.graph_mode():
      ops_without_writer = define_ops()
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          ops_recording_on = define_ops()
        with summary_ops.record_if(False):
          ops_recording_off = define_ops()
      # We should be collecting all ops defined with a default writer present,
      # regardless of whether recording was set on or off, but not those defined
      # without a writer at all.
      del ops_without_writer
      expected_ops = ops_recording_on + ops_recording_off
      self.assertCountEqual(expected_ops, summary_ops.all_v2_summary_ops())

  def testShouldRecordSummaries_defaultState(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      w = summary_ops.create_file_writer_v2(logdir)
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with w.as_default():
        # Should be enabled only when default writer is registered.
        self.assertAllEqual(True, summary_ops.should_record_summaries())
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with summary_ops.record_if(True):
        # Should be disabled when no default writer, even with record_if(True).
        self.assertAllEqual(False, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_constants(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          self.assertAllEqual(True, summary_ops.should_record_summaries())
        with summary_ops.record_if(False):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
        with summary_ops.record_if(True):
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_variable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond = variables.Variable(False)
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond.assign(True)
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_callable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond_box = [False]
        cond = lambda: cond_box[0]
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond_box[0] = True
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool)])
      def f(cond):
        results = []
        results.append(summary_ops.should_record_summaries())
        with writer.as_default():
          results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(False):
            results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(cond):
            results.append(summary_ops.should_record_summaries())
        return results
      self.assertAllEqual([False, True, False, True], f(True))
      self.assertAllEqual([False, True, False, False], f(False))

  def testHasDefaultWriter_checkWriter(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with self.subTest(name='has_writer'):
        with summary_ops.create_file_writer_v2(logdir).as_default():
          self.assertTrue(summary_ops.has_default_writer())
      with self.subTest(name='no_writer'):
        self.assertFalse(summary_ops.has_default_writer())


class SummaryWriterTest(test_util.TensorFlowTestCase):

  def testCreate_withInitAndClose(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      get_total = lambda: len(events_from_logdir(logdir))
      self.assertEqual(1, get_total())  # file_version Event
      # Calling init() again while writer is open has no effect
      writer.init()
      self.assertEqual(1, get_total())
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Calling .close() should do an implicit flush
        writer.close()
        self.assertEqual(2, get_total())

  def testCreate_fromFunction(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      # Returned SummaryWriter must be stored in a non-local variable so it
      # lives throughout the function execution.
      if not hasattr(f, 'writer'):
        f.writer = summary_ops.create_file_writer_v2(logdir)
    with context.eager_mode():
      f()
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertEqual(1, len(event_files))

  def testCreate_graphTensorArgument_raisesError(self):
    logdir = self.get_temp_dir()
    with context.graph_mode():
      logdir_tensor = constant_op.constant(logdir)
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        summary_ops.create_file_writer_v2(logdir_tensor)
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        f()
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_unpersistedResource_raisesError(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        pass  # Calling .as_default() is enough to indicate use.
    with context.eager_mode():
      # TODO(nickfelt): change this to a better error
      with self.assertRaisesRegex(
          errors.NotFoundError, 'Resource.*does not exist'):
        f()
    # Even though we didn't use it, an event file will have been created.
    self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*'))))

  def testCreate_immediateSetAsDefault_retainsReference(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        summary_ops.create_file_writer_v2(logdir).set_as_default()
        summary_ops.flush()
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testCreate_immediateAsDefault_retainsReference(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.flush()

  def testCreate_avoidsFilenameCollision(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      for _ in range(10):
        summary_ops.create_file_writer_v2(logdir)
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testCreate_graphMode_avoidsFilenameCollision(self):
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      writer = summary_ops.create_file_writer_v2(logdir)
      with self.cached_session() as sess:
        for _ in range(10):
          sess.run(writer.init())
          sess.run(writer.close())
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testNoSharing(self):
    # Two writers with the same logdir should not share state.
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer1 = summary_ops.create_file_writer_v2(logdir)
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(event_files))
      file1 = event_files[0]

      writer2 = summary_ops.create_file_writer_v2(logdir)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(2, len(event_files))
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)

    events = iter(events_from_file(file1))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(1, next(events).step)
    self.assertEqual(1, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))
    events = iter(events_from_file(file2))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(2, next(events).step)
    self.assertEqual(2, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))

  def testNoSharing_fromFunction(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f1():
      if not hasattr(f1, 'writer'):
        f1.writer = summary_ops.create_file_writer_v2(logdir)
      with f1.writer.as_default():
        summary_ops.write('tag', 1, step=1)
    @def_function.function
    def f2():
      if not hasattr(f2, 'writer'):
        f2.writer = summary_ops.create_file_writer_v2(logdir)
      with f2.writer.as_default():
        summary_ops.write('tag', 1, step=2)
    with context.eager_mode():
      f1()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(event_files))
      file1 = event_files[0]

      f2()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(2, len(event_files))
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      f1()
      f2()

    events = iter(events_from_file(file1))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(1, next(events).step)
    self.assertEqual(1, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))
    events = iter(events_from_file(file2))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(2, next(events).step)
    self.assertEqual(2, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))

  def testMaxQueue(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(
          logdir, max_queue=1, flush_millis=999999).as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Should flush after second summary since max_queue = 1
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())

  def testWriterFlush(self):
    logdir = self.get_temp_dir()
    get_total = lambda: len(events_from_logdir(logdir))
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      self.assertEqual(1, get_total())  # file_version Event
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        writer.flush()
        self.assertEqual(2, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(2, get_total())
      # Exiting the "as_default()" should do an implicit flush
      self.assertEqual(3, get_total())

  def testFlushFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=999999, flush_millis=999999)
      with writer.as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        summary_ops.flush()
        self.assertEqual(3, get_total())
        # Test "writer" parameter
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())
        summary_ops.flush(writer=writer)
        self.assertEqual(4, get_total())

  # Regression test for b/228097117.
  def testFlushFunction_disallowsInvalidWriterInput(self):
    with context.eager_mode():
      with self.assertRaisesRegex(ValueError, 'Invalid argument to flush'):
        summary_ops.flush(writer=())

  @test_util.assert_no_new_tensors
  def testNoMemoryLeak_graphMode(self):
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      summary_ops.create_file_writer_v2(logdir)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNoMemoryLeak_eagerMode(self):
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer_v2(logdir).as_default():
      summary_ops.write('tag', 1, step=0)

  def testClose_preventsLaterUse(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      writer.close()
      writer.close()  # redundant close() is a no-op
      writer.flush()  # redundant flush() is a no-op
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.init()
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        with writer.as_default():
          self.fail('should not get here')
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.set_as_default()

  def testClose_closesOpenFile(self):
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(files))
      eventfile = files[0]
      self.assertIn(eventfile, get_open_filenames())
      writer.close()
      self.assertNotIn(eventfile, get_open_filenames())

  def testDereference_closesOpenFile(self):
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(files))
      eventfile = files[0]
      self.assertIn(eventfile, get_open_filenames())
      del writer
      self.assertNotIn(eventfile, get_open_filenames())


class SummaryWriterSavedModelTest(test_util.TensorFlowTestCase):

  def testWriter_savedAsModuleProperty_loadInEagerMode(self):
    with context.eager_mode():
      class Model(module.Module):

        def __init__(self, model_dir):
          self._writer = summary_ops.create_file_writer_v2(
              model_dir, experimental_trackable=True)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
        ])
        def train(self, step):
          with self._writer.as_default():
            summary_ops.write('tag', 'foo', step=step)
            return constant_op.constant(0)

      logdir = self.get_temp_dir()
      to_export = Model(logdir)
      pre_save_files = set(events_from_multifile_logdir(logdir))
      export_dir = os.path.join(logdir, 'export')
      saved_model_save.save(
          to_export, export_dir, signatures={'train': to_export.train})

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access
    with context.eager_mode():
      restored = saved_model_load.load(export_dir)
      restored.train(1)
      restored.train(2)
      post_restore_files = set(events_from_multifile_logdir(logdir))
      restored2 = saved_model_load.load(export_dir)
      restored2.train(3)
      restored2.train(4)
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)
      self.assertLen(files_to_events, 3)
      def unwrap_singleton(iterable):
        self.assertLen(iterable, 1)
        return next(iter(iterable))
      restore_file = unwrap_singleton(post_restore_files - pre_save_files)
      restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
      restore_events = files_to_events[restore_file]
      restore2_events = files_to_events[restore2_file]
      self.assertLen(restore_events, 3)
      self.assertEqual(1, restore_events[1].step)
      self.assertEqual(2, restore_events[2].step)
      self.assertLen(restore2_events, 3)
      self.assertEqual(3, restore2_events[1].step)
      self.assertEqual(4, restore2_events[2].step)

  def testWriter_savedAsModuleProperty_loadInGraphMode(self):
    with context.eager_mode():

      class Model(module.Module):

        def __init__(self, model_dir):
          self._writer = summary_ops.create_file_writer_v2(
              model_dir, experimental_trackable=True)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
        ])
        def train(self, step):
          with self._writer.as_default():
            summary_ops.write('tag', 'foo', step=step)
            return constant_op.constant(0)

      logdir = self.get_temp_dir()
      to_export = Model(logdir)
      pre_save_files = set(events_from_multifile_logdir(logdir))
      export_dir = os.path.join(logdir, 'export')
      saved_model_save.save(
          to_export, export_dir, signatures={'train': to_export.train})

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access

    def load_and_run_model(sess, input_values):
      """Load and run the SavedModel signature in the TF 1.x style."""
      model = saved_model_loader.load(sess, [tag_constants.SERVING], export_dir)
      signature = model.signature_def['train']
      inputs = list(signature.inputs.values())
      assert len(inputs) == 1, inputs
      outputs = list(signature.outputs.values())
      assert len(outputs) == 1, outputs
      input_tensor = sess.graph.get_tensor_by_name(inputs[0].name)
      output_tensor = sess.graph.get_tensor_by_name(outputs[0].name)
      for v in input_values:
        sess.run(output_tensor, feed_dict={input_tensor: v})

    with context.graph_mode(), ops.Graph().as_default():
      # Since writer shared_name is fixed, within a single session, all loads of
      # this SavedModel will refer to a single writer resource, so it will be
      # initialized only once and write to a single file.
      with self.session() as sess:
        load_and_run_model(sess, [1, 2])
        load_and_run_model(sess, [3, 4])
      post_restore_files = set(events_from_multifile_logdir(logdir))
      # New session will recreate the resource and write to a second file.
      with self.session() as sess:
        load_and_run_model(sess, [5, 6])
    files_to_events = events_from_multifile_logdir(logdir)
    post_restore2_files = set(files_to_events)

    self.assertLen(files_to_events, 3)
    def unwrap_singleton(iterable):
      self.assertLen(iterable, 1)
      return next(iter(iterable))
    restore_file = unwrap_singleton(post_restore_files - pre_save_files)
    restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
    restore_events = files_to_events[restore_file]
    restore2_events = files_to_events[restore2_file]
    self.assertLen(restore_events, 5)
    self.assertEqual(1, restore_events[1].step)
    self.assertEqual(2, restore_events[2].step)
    self.assertEqual(3, restore_events[3].step)
    self.assertEqual(4, restore_events[4].step)
    self.assertLen(restore2_events, 3)
    self.assertEqual(5, restore2_events[1].step)
    self.assertEqual(6, restore2_events[2].step)


class NoopWriterTest(test_util.TensorFlowTestCase):

  def testNoopWriter_doesNothing(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_noop_writer()
      writer.init()
      with writer.as_default():
        result = summary_ops.write('test', 1.0, step=0)
      writer.flush()
      writer.close()
    self.assertFalse(result)  # Should have found no active writer
    files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(files, 0)

  def testNoopWriter_asNestedContext_isTransparent(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      noop_writer = summary_ops.create_noop_writer()
      with writer.as_default():
        result1 = summary_ops.write('first', 1.0, step=0)
        with noop_writer.as_default():
          result2 = summary_ops.write('second', 1.0, step=0)
        result3 = summary_ops.write('third', 1.0, step=0)
      # All ops should have written, including the one inside the no-op writer,
      # since it doesn't actively *disable* writing - it just behaves as if that
      # entire `with` block wasn't there at all.
      self.assertAllEqual([result1, result2, result3], [True, True, True])

  def testNoopWriter_setAsDefault(self):
    try:
      with context.eager_mode():
        writer = summary_ops.create_noop_writer()
        writer.set_as_default()
        result = summary_ops.write('test', 1.0, step=0)
        self.assertFalse(result)  # Should have found no active writer
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access


class SummaryOpsTest(test_util.TensorFlowTestCase):

  def tearDown(self):
    summary_ops.trace_off()
    super().tearDown()

  def exec_summary_op(self, summary_op_fn):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_op_fn()
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata(self, *args, **kwargs):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_ops.run_metadata(*args, **kwargs)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata_graphs(self, *args, **kwargs):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_ops.run_metadata_graphs(*args, **kwargs)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def create_run_metadata(self):
    step_stats = step_stats_pb2.StepStats(dev_stats=[
        step_stats_pb2.DeviceStepStats(
            device='cpu:0',
            node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')])
    ])
    return config_pb2.RunMetadata(
        function_graphs=[
            config_pb2.RunMetadata.FunctionGraphs(
                pre_optimization_graph=graph_pb2.GraphDef(
                    node=[node_def_pb2.NodeDef(name='foo')]))
        ],
        step_stats=step_stats)

  def run_trace(self, f, step=1):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=False)
    with writer.as_default():
      f()
      summary_ops.trace_export(name='foo', step=step)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  @test_util.run_v2_only
  def testRunMetadata_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadata_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadata_wholeRunMetadata(self):
    expected_run_metadata = """
      step_stats {
        dev_stats {
          device: "cpu:0"
          node_stats {
            node_name: "hello"
          }
        }
      }
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadata_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadataGraph_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata_graph"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_runMetadataFragment(self):
    expected_run_metadata = """
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()

    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata_graphs(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    event = self.run_trace(f)

    first_val = event.summary.value[0]
    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])

    # Content of function_graphs is large and, for instance, device can change.
    self.assertTrue(hasattr(actual_run_metadata, 'function_graphs'))

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInFunction(self):

    @def_function.function
    def f():
      summary_ops.trace_on(graph=True, profiler=False)
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot enable trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_on(graph=True, profiler=False)
      self.assertRegex(
          str(mock_log.call_args), 'Must enable trace in eager mode.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceWithoutTrace(self):
    with self.assertRaisesRegex(ValueError, 'Must enable trace before export.'):
      summary_ops.trace_export(name='foo', step=1)

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInFunction(self):
    summary_ops.trace_on(graph=True, profiler=False)

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      summary_ops.trace_export(name='foo', step=1)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot export trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_export(name='foo', step=1)
      self.assertRegex(
          str(mock_log.call_args),
          'Can only export trace while executing eagerly.')

  @test_util.run_v2_only
  def testTrace_usesDefaultStep(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    try:
      summary_ops.set_step(42)
      event = self.run_trace(f, step=None)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace_withProfiler(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=True)
    profiler_outdir = self.get_temp_dir()
    with writer.as_default():
      f()
      summary_ops.trace_export(
          name='foo', step=1, profiler_outdir=profiler_outdir)
    writer.close()

  @test_util.run_v2_only
  def testGraph_graph(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph)

    event = self.exec_summary_op(summary_op_fn)
    self.assertIsNotNone(event.graph_def)

  @test_util.run_v2_only
  def testGraph_graphDef(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph.as_graph_def())

    event = self.exec_summary_op(summary_op_fn)
    self.assertIsNotNone(event.graph_def)

  @test_util.run_v2_only
  def testGraph_invalidData(self):
    def summary_op_fn():
      summary_ops.graph('hello')

    with self.assertRaisesRegex(
        ValueError,
        r'\'graph_data\' is not tf.Graph or tf.compat.v1.GraphDef',
    ):
      self.exec_summary_op(summary_op_fn)

  @test_util.run_v2_only
  def testGraph_fromGraphMode(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    @def_function.function
    def g(graph):
      summary_ops.graph(graph)

    def summary_op_fn():
      graph_def = f.get_concrete_function().graph.as_graph_def(add_shapes=True)
      func_graph = constant_op.constant(graph_def.SerializeToString())
      g(func_graph)

    with self.assertRaisesRegex(
        ValueError,
        r'graph\(\) cannot be invoked inside a graph context.',
    ):
      self.exec_summary_op(summary_op_fn)


def events_from_file(filepath):
  """Returns all events in a single event file.

  Args:
    filepath: Path to the event file.

  Returns:
    A list of all tf.Event protos in the event file.
  """
  records = list(tf_record.tf_record_iterator(filepath))
  result = []
  for r in records:
    event = event_pb2.Event()
    event.ParseFromString(r)
    result.append(event)
  return result


def events_from_logdir(logdir):
  """Returns all events in the single eventfile in logdir.

  Args:
    logdir: The directory in which the single event file is sought.

  Returns:
    A list of all tf.Event protos from the single event file.

  Raises:
    AssertionError: If logdir does not contain exactly one file.
  """
  assert gfile.Exists(logdir)
  files = gfile.ListDirectory(logdir)
  assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
  return events_from_file(os.path.join(logdir, files[0]))


def events_from_multifile_logdir(logdir):
  """Returns map of filename to events for all `tfevents` files in the logdir.

  Args:
    logdir: The directory from which to load events.

  Returns:
    A dict mapping from relative filenames to lists of tf.Event protos.

  Raises:
    AssertionError: If logdir does not exist.
  """
  assert gfile.Exists(logdir)
  files = [file for file in gfile.ListDirectory(logdir) if 'tfevents' in file]
  return {file: events_from_file(os.path.join(logdir, file)) for file in files}


def to_numpy(summary_value):
  return tensor_util.MakeNdarray(summary_value.tensor)


if __name__ == '__main__':
  test.main()