# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V2 summary ops from summary_ops_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest

import six

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load as saved_model_load
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import save as saved_model_save
from tensorflow.python.saved_model import tag_constants


class SummaryOpsCoreTest(test_util.TensorFlowTestCase):

  def testWrite(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write('tag', 42, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write('tag', 42, step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    value = events[1].summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def testWrite_metadata(self):
    logdir = self.get_temp_dir()
    metadata = summary_pb2.SummaryMetadata()
    metadata.plugin_data.plugin_name = 'foo'
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('obj', 0, 0, metadata=metadata)
        summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString())
        m = constant_op.constant(metadata.SerializeToString())
        summary_ops.write('string_tensor', 0, 0, metadata=m)
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(metadata, events[1].summary.value[0].metadata)
    self.assertEqual(metadata, events[2].summary.value[0].metadata)
    self.assertEqual(metadata, events[3].summary.value[0].metadata)

  def testWrite_name(self):
    @def_function.function
    def f():
      output = summary_ops.write('tag', 42, step=12, name='anonymous')
      self.assertTrue(output.name.startswith('anonymous'))
    f()

  def testWrite_ndarray(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value))

  def testWrite_tensor(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      t = constant_op.constant([[1, 2], [3, 4]])
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', t, step=12)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_tensor_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f(t):
        with writer.as_default():
          summary_ops.write('tag', t, step=12)
      t = constant_op.constant([[1, 2], [3, 4]])
      f(t)
      expected = t.numpy()
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite_stringTensor(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [b'foo', b'bar'], step=12)
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual([b'foo', b'bar'], to_numpy(value))

  @test_util.run_gpu_only
  def testWrite_gpuDeviceContext(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with ops.device('/GPU:0'):
          value = constant_op.constant(42.0)
          step = constant_op.constant(12, dtype=dtypes.int64)
          summary_ops.write('tag', value, step=step).numpy()
    empty_metadata = summary_pb2.SummaryMetadata()
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertEqual(42, to_numpy(events[1].summary.value[0]))
    self.assertEqual(empty_metadata, events[1].summary.value[0].metadata)

  @test_util.also_run_as_tf_function
  def testWrite_noDefaultWriter(self):
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42, step=0))

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfAlsoNoDefaultWriter(self):
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42))

  def testWrite_noStep(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(ValueError, 'No step set'):
          summary_ops.write('tag', 42)

  def testWrite_noStep_okayIfNotRecordingSummaries(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('tag', 42))

  def testWrite_usingDefaultStep(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        with summary_ops.create_file_writer_v2(logdir).as_default():
          summary_ops.set_step(1)
          summary_ops.write('tag', 1.0)
          summary_ops.set_step(2)
          summary_ops.write('tag', 1.0)
          mystep = variables.Variable(10, dtype=dtypes.int64)
          summary_ops.set_step(mystep)
          summary_ops.write('tag', 1.0)
          mystep.assign_add(1)
          summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertEqual(5, len(events))
      self.assertEqual(1, events[1].step)
      self.assertEqual(2, events[2].step)
      self.assertEqual(10, events[3].step)
      self.assertEqual(11, events[4].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromFunction(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        summary_ops.set_step(1)
        f()
        summary_ops.set_step(2)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the function was first traced.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromFunction(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        @def_function.function
        def f():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        f()
        mystep.assign_add(1)
        f()
        mystep.assign(10)
        f()
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromLegacyGraph(self):
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        summary_ops.set_step(1)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        summary_ops.set_step(2)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(write_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(3, len(events))
      self.assertEqual(1, events[1].step)
      # The step value will still be 1 because the value was captured at the
      # time the graph was constructed.
      self.assertEqual(1, events[2].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromLegacyGraph(self):
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(mystep)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        first_assign_op = mystep.assign_add(1)
        second_assign_op = mystep.assign(10)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(mystep.initializer)
          sess.run(write_op)
          sess.run(first_assign_op)
          sess.run(write_op)
          sess.run(second_assign_op)
          sess.run(write_op)
          sess.run(writer.flush())
      events = events_from_logdir(logdir)
      self.assertEqual(4, len(events))
      self.assertEqual(0, events[1].step)
      self.assertEqual(1, events[2].step)
      self.assertEqual(10, events[3].step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        with writer.as_default(step=1):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=2):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            summary_ops.set_step(3)
            summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 1, 2, 1, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        with writer.as_default(step=mystep):
          summary_ops.write('tag', 1.0)
          with writer.as_default():
            mystep.assign(2)
            summary_ops.write('tag', 1.0)
            with writer.as_default(step=3):
              summary_ops.write('tag', 1.0)
            summary_ops.write('tag', 1.0)
            mystep.assign(4)
            summary_ops.write('tag', 1.0)
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3, 2, 4], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromSetAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        mystep = variables.Variable(1, dtype=dtypes.int64)
        writer.set_as_default(step=mystep)
        summary_ops.write('tag', 1.0)
        mystep.assign(2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=3)
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 3], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStep_fromSetAsDefault(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer_v2(logdir)
        writer.set_as_default(step=1)
        summary_ops.write('tag', 1.0)
        writer.set_as_default(step=2)
        summary_ops.write('tag', 1.0)
        writer.set_as_default()
        summary_ops.write('tag', 1.0)
        writer.flush()
      events = events_from_logdir(logdir)
      self.assertListEqual([1, 2, 2], [e.step for e in events[1:]])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_recordIf_constant(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        self.assertTrue(summary_ops.write('default', 1, step=0))
        with summary_ops.record_if(True):
          self.assertTrue(summary_ops.write('set_on', 1, step=0))
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('set_off', 1, step=0))
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_constant_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          # Use assertAllEqual instead of assertTrue since it works in a defun.
          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
          with summary_ops.record_if(True):
            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
          with summary_ops.record_if(False):
            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
      f()
    events = events_from_logdir(logdir)
    self.assertEqual(3, len(events))
    self.assertEqual('default', events[1].summary.value[0].tag)
    self.assertEqual('set_on', events[2].summary.value[0].tag)

  def testWrite_recordIf_callable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      step = variables.Variable(-1, dtype=dtypes.int64)
      def record_fn():
        step.assign_add(1)
        return int(step % 2) == 0
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(record_fn):
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
          self.assertFalse(summary_ops.write('tag', 1, step=step))
          self.assertTrue(summary_ops.write('tag', 1, step=step))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_callable_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      step = variables.Variable(-1, dtype=dtypes.int64)
      @def_function.function
      def record_fn():
        step.assign_add(1)
        return math_ops.equal(step % 2, 0)
      @def_function.function
      def f():
        with writer.as_default():
          with summary_ops.record_if(record_fn):
            return [
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step)]
      self.assertAllEqual(f(), [True, False, True])
      self.assertAllEqual(f(), [False, True, False])
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWrite_recordIf_tensorInput_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
      def f(step):
        with writer.as_default():
          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
            return summary_ops.write('tag', 1, step=step)
      self.assertTrue(f(0))
      self.assertFalse(f(1))
      self.assertTrue(f(2))
      self.assertFalse(f(3))
      self.assertTrue(f(4))
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    self.assertEqual(0, events[1].step)
    self.assertEqual(2, events[2].step)
    self.assertEqual(4, events[3].step)

  def testWriteRawPb(self):
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_fromFunction(self):
    logdir = self.get_temp_dir()
    pb = summary_pb2.Summary()
    pb.value.add().simple_value = 42.0
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def f():
        with writer.as_default():
          return summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
      output = f()
      self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    self.assertProtoEquals(pb, events[1].summary)

  def testWriteRawPb_multipleValues(self):
    logdir = self.get_temp_dir()
    pb1 = summary_pb2.Summary()
    pb1.value.add().simple_value = 1.0
    pb1.value.add().simple_value = 2.0
    pb2 = summary_pb2.Summary()
    pb2.value.add().simple_value = 3.0
    pb3 = summary_pb2.Summary()
    pb3.value.add().simple_value = 4.0
    pb3.value.add().simple_value = 5.0
    pb3.value.add().simple_value = 6.0
    pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)]
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        output = summary_ops.write_raw_pb(pbs, step=12)
        self.assertTrue(output.numpy())
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    self.assertEqual(12, events[1].step)
    expected_pb = summary_pb2.Summary()
    for i in range(6):
      expected_pb.value.add().simple_value = i + 1.0
    self.assertProtoEquals(expected_pb, events[1].summary)

  def testWriteRawPb_invalidValue(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with self.assertRaisesRegex(
            errors.DataLossError,
            'Bad tf.compat.v1.Summary binary proto tensor string'):
          summary_ops.write_raw_pb('notaproto', step=12)

  @test_util.also_run_as_tf_function
  def testGetSetStep(self):
    try:
      self.assertIsNone(summary_ops.get_step())
      summary_ops.set_step(1)
      # Use assertAllEqual instead of assertEqual since it works in a defun.
      self.assertAllEqual(1, summary_ops.get_step())
      summary_ops.set_step(constant_op.constant(2))
      self.assertAllEqual(2, summary_ops.get_step())
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testGetSetStep_variable(self):
    with context.eager_mode():
      try:
        mystep = variables.Variable(0)
        summary_ops.set_step(mystep)
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        mystep.assign_add(1)
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        summary_ops.get_step().assign_add(1)
        self.assertAllEqual(2, summary_ops.get_step().read_value())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  def testGetSetStep_variable_fromFunction(self):
    with context.eager_mode():
      try:
        @def_function.function
        def set_step(step):
          summary_ops.set_step(step)
          return summary_ops.get_step()
        @def_function.function
        def get_and_increment():
          summary_ops.get_step().assign_add(1)
          return summary_ops.get_step()
        mystep = variables.Variable(0)
        self.assertAllEqual(0, set_step(mystep))
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        self.assertAllEqual(1, get_and_increment())
        self.assertAllEqual(2, get_and_increment())
        # Check that set_step() properly maintains reference to variable.
        del mystep
        self.assertAllEqual(3, get_and_increment())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  @test_util.also_run_as_tf_function
  def testSummaryScope(self):
    with summary_ops.summary_scope('foo') as (tag, scope):
      self.assertEqual('foo', tag)
      self.assertEqual('foo/', scope)
      with summary_ops.summary_scope('bar') as (tag, scope):
        self.assertEqual('foo/bar', tag)
        self.assertEqual('foo/bar/', scope)
      with summary_ops.summary_scope('with/slash') as (tag, scope):
        self.assertEqual('foo/with/slash', tag)
        self.assertEqual('foo/with/slash/', scope)
      with ops.name_scope(None, skip_on_eager=False):
        with summary_ops.summary_scope('unnested') as (tag, scope):
          self.assertEqual('unnested', tag)
          self.assertEqual('unnested/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_defaultName(self):
    with summary_ops.summary_scope(None) as (tag, scope):
      self.assertEqual('summary', tag)
      self.assertEqual('summary/', scope)
    with summary_ops.summary_scope(None, 'backup') as (tag, scope):
      self.assertEqual('backup', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_handlesCharactersIllegalForScope(self):
    with summary_ops.summary_scope('f?o?o') as (tag, scope):
      self.assertEqual('f?o?o', tag)
      self.assertEqual('foo/', scope)
    # If all characters aren't legal for a scope name, use default name.
    with summary_ops.summary_scope('???', 'backup') as (tag, scope):
      self.assertEqual('???', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_nameNotUniquifiedForTag(self):
    constant_op.constant(0, name='foo')
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with ops.name_scope('with', skip_on_eager=False):
      constant_op.constant(0, name='slash')
    with summary_ops.summary_scope('with/slash') as (tag, _):
      self.assertEqual('with/slash', tag)

  def testAllV2SummaryOps(self):
    logdir = self.get_temp_dir()
    def define_ops():
      result = []
      # TF 2.0 summary ops
      result.append(summary_ops.write('write', 1, step=0))
      result.append(summary_ops.write_raw_pb(b'', step=0, name='raw_pb'))
      # TF 1.x tf.contrib.summary ops
      result.append(summary_ops.generic('tensor', 1, step=1))
      result.append(summary_ops.scalar('scalar', 2.0, step=1))
      result.append(summary_ops.histogram('histogram', [1.0], step=1))
      result.append(summary_ops.image('image', [[[[1.0]]]], step=1))
      result.append(summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1))
      return result
    with context.graph_mode():
      ops_without_writer = define_ops()
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          ops_recording_on = define_ops()
        with summary_ops.record_if(False):
          ops_recording_off = define_ops()
      # We should be collecting all ops defined with a default writer present,
      # regardless of whether recording was set on or off, but not those defined
      # without a writer at all.
      del ops_without_writer
      expected_ops = ops_recording_on + ops_recording_off
      self.assertCountEqual(expected_ops, summary_ops.all_v2_summary_ops())

  def testShouldRecordSummaries_defaultState(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      w = summary_ops.create_file_writer_v2(logdir)
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with w.as_default():
        # Should be enabled only when default writer is registered.
        self.assertAllEqual(True, summary_ops.should_record_summaries())
      self.assertAllEqual(False, summary_ops.should_record_summaries())
      with summary_ops.record_if(True):
        # Should be disabled when no default writer, even with record_if(True).
        self.assertAllEqual(False, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_constants(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(True):
          self.assertAllEqual(True, summary_ops.should_record_summaries())
        with summary_ops.record_if(False):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          with summary_ops.record_if(True):
            self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_variable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond = variables.Variable(False)
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond.assign(True)
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_callable(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        cond_box = [False]
        cond = lambda: cond_box[0]
        with summary_ops.record_if(cond):
          self.assertAllEqual(False, summary_ops.should_record_summaries())
          cond_box[0] = True
          self.assertAllEqual(True, summary_ops.should_record_summaries())

  def testShouldRecordSummaries_fromFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool)])
      def f(cond):
        results = []
        results.append(summary_ops.should_record_summaries())
        with writer.as_default():
          results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(False):
            results.append(summary_ops.should_record_summaries())
          with summary_ops.record_if(cond):
            results.append(summary_ops.should_record_summaries())
        return results
      self.assertAllEqual([False, True, False, True], f(True))
      self.assertAllEqual([False, True, False, False], f(False))


class SummaryWriterTest(test_util.TensorFlowTestCase):

  def testCreate_withInitAndClose(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      get_total = lambda: len(events_from_logdir(logdir))
      self.assertEqual(1, get_total())  # file_version Event
      # Calling init() again while writer is open has no effect
      writer.init()
      self.assertEqual(1, get_total())
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Calling .close() should do an implicit flush
        writer.close()
        self.assertEqual(2, get_total())

  def testCreate_fromFunction(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      # Returned SummaryWriter must be stored in a non-local variable so it
      # lives throughout the function execution.
      if not hasattr(f, 'writer'):
        f.writer = summary_ops.create_file_writer_v2(logdir)
    with context.eager_mode():
      f()
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertEqual(1, len(event_files))

  def testCreate_graphTensorArgument_raisesError(self):
    logdir = self.get_temp_dir()
    with context.graph_mode():
      logdir_tensor = constant_op.constant(logdir)
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        summary_ops.create_file_writer_v2(logdir_tensor)
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        f()
    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))

  def testCreate_fromFunction_unpersistedResource_raisesError(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        pass  # Calling .as_default() is enough to indicate use.
    with context.eager_mode():
      # TODO(nickfelt): change this to a better error
      with self.assertRaisesRegex(
          errors.NotFoundError, 'Resource.*does not exist'):
        f()
    # Even though we didn't use it, an event file will have been created.
    self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*'))))

  def testCreate_immediateSetAsDefault_retainsReference(self):
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        summary_ops.create_file_writer_v2(logdir).set_as_default()
        summary_ops.flush()
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access

  def testCreate_immediateAsDefault_retainsReference(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.flush()

  def testCreate_avoidsFilenameCollision(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      for _ in range(10):
        summary_ops.create_file_writer_v2(logdir)
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testCreate_graphMode_avoidsFilenameCollision(self):
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      writer = summary_ops.create_file_writer_v2(logdir)
      with self.cached_session() as sess:
        for _ in range(10):
          sess.run(writer.init())
          sess.run(writer.close())
    event_files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(event_files, 10)

  def testNoSharing(self):
    # Two writers with the same logdir should not share state.
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer1 = summary_ops.create_file_writer_v2(logdir)
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(event_files))
      file1 = event_files[0]

      writer2 = summary_ops.create_file_writer_v2(logdir)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(2, len(event_files))
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      with writer1.as_default():
        summary_ops.write('tag', 1, step=1)
      with writer2.as_default():
        summary_ops.write('tag', 1, step=2)

    events = iter(events_from_file(file1))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(1, next(events).step)
    self.assertEqual(1, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))
    events = iter(events_from_file(file2))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(2, next(events).step)
    self.assertEqual(2, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))

  def testNoSharing_fromFunction(self):
    logdir = self.get_temp_dir()
    @def_function.function
    def f1():
      if not hasattr(f1, 'writer'):
        f1.writer = summary_ops.create_file_writer_v2(logdir)
      with f1.writer.as_default():
        summary_ops.write('tag', 1, step=1)
    @def_function.function
    def f2():
      if not hasattr(f2, 'writer'):
        f2.writer = summary_ops.create_file_writer_v2(logdir)
      with f2.writer.as_default():
        summary_ops.write('tag', 1, step=2)
    with context.eager_mode():
      f1()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(event_files))
      file1 = event_files[0]

      f2()
      event_files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(2, len(event_files))
      event_files.remove(file1)
      file2 = event_files[0]

      # Extra writes to ensure interleaved usage works.
      f1()
      f2()

    events = iter(events_from_file(file1))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(1, next(events).step)
    self.assertEqual(1, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))
    events = iter(events_from_file(file2))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(2, next(events).step)
    self.assertEqual(2, next(events).step)
    self.assertRaises(StopIteration, lambda: next(events))

  def testMaxQueue(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(
          logdir, max_queue=1, flush_millis=999999).as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        # Should flush after second summary since max_queue = 1
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())

  def testWriterFlush(self):
    logdir = self.get_temp_dir()
    get_total = lambda: len(events_from_logdir(logdir))
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      self.assertEqual(1, get_total())  # file_version Event
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        writer.flush()
        self.assertEqual(2, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(2, get_total())
      # Exiting the "as_default()" should do an implicit flush
      self.assertEqual(3, get_total())

  def testFlushFunction(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=999999, flush_millis=999999)
      with writer.as_default():
        get_total = lambda: len(events_from_logdir(logdir))
        # Note: First tf.compat.v1.Event is always file_version.
        self.assertEqual(1, get_total())
        summary_ops.write('tag', 1, step=0)
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, get_total())
        summary_ops.flush()
        self.assertEqual(3, get_total())
        # Test "writer" parameter
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, get_total())
        summary_ops.flush(writer=writer)
        self.assertEqual(4, get_total())
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(4, get_total())
        summary_ops.flush(writer=writer._resource)  # pylint:disable=protected-access
        self.assertEqual(5, get_total())

  @test_util.assert_no_new_tensors
  def testNoMemoryLeak_graphMode(self):
    logdir = self.get_temp_dir()
    with context.graph_mode(), ops.Graph().as_default():
      summary_ops.create_file_writer_v2(logdir)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testNoMemoryLeak_eagerMode(self):
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer_v2(logdir).as_default():
      summary_ops.write('tag', 1, step=0)

  def testClose_preventsLaterUse(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      writer.close()
      writer.close()  # redundant close() is a no-op
      writer.flush()  # redundant flush() is a no-op
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.init()
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        with writer.as_default():
          self.fail('should not get here')
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.set_as_default()

  def testClose_closesOpenFile(self):
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(files))
      eventfile = files[0]
      self.assertIn(eventfile, get_open_filenames())
      writer.close()
      self.assertNotIn(eventfile, get_open_filenames())

  def testDereference_closesOpenFile(self):
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    get_open_filenames = lambda: set(info[0] for info in proc.open_files())
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      files = gfile.Glob(os.path.join(logdir, '*'))
      self.assertEqual(1, len(files))
      eventfile = files[0]
      self.assertIn(eventfile, get_open_filenames())
      del writer
      self.assertNotIn(eventfile, get_open_filenames())


class SummaryWriterSavedModelTest(test_util.TensorFlowTestCase):

  def testWriter_savedAsModuleProperty_loadInEagerMode(self):
    with context.eager_mode():
      class Model(module.Module):

        def __init__(self, model_dir):
          self._writer = summary_ops.create_file_writer_v2(
              model_dir, experimental_trackable=True)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
        ])
        def train(self, step):
          with self._writer.as_default():
            summary_ops.write('tag', 'foo', step=step)
            return constant_op.constant(0)

      logdir = self.get_temp_dir()
      to_export = Model(logdir)
      pre_save_files = set(events_from_multifile_logdir(logdir))
      export_dir = os.path.join(logdir, 'export')
      saved_model_save.save(
          to_export, export_dir, signatures={'train': to_export.train})

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access
    with context.eager_mode():
      restored = saved_model_load.load(export_dir)
      restored.train(1)
      restored.train(2)
      post_restore_files = set(events_from_multifile_logdir(logdir))
      restored2 = saved_model_load.load(export_dir)
      restored2.train(3)
      restored2.train(4)
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)
      self.assertLen(files_to_events, 3)
      def unwrap_singleton(iterable):
        self.assertLen(iterable, 1)
        return next(iter(iterable))
      restore_file = unwrap_singleton(post_restore_files - pre_save_files)
      restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
      restore_events = files_to_events[restore_file]
      restore2_events = files_to_events[restore2_file]
      self.assertLen(restore_events, 3)
      self.assertEqual(1, restore_events[1].step)
      self.assertEqual(2, restore_events[2].step)
      self.assertLen(restore2_events, 3)
      self.assertEqual(3, restore2_events[1].step)
      self.assertEqual(4, restore2_events[2].step)

  def testWriter_savedAsModuleProperty_loadInGraphMode(self):
    with context.eager_mode():

      class Model(module.Module):

        def __init__(self, model_dir):
          self._writer = summary_ops.create_file_writer_v2(
              model_dir, experimental_trackable=True)

        @def_function.function(input_signature=[
            tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
        ])
        def train(self, step):
          with self._writer.as_default():
            summary_ops.write('tag', 'foo', step=step)
            return constant_op.constant(0)

      logdir = self.get_temp_dir()
      to_export = Model(logdir)
      pre_save_files = set(events_from_multifile_logdir(logdir))
      export_dir = os.path.join(logdir, 'export')
      saved_model_save.save(
          to_export, export_dir, signatures={'train': to_export.train})

    # Reset context to ensure we don't share any resources with saving code.
    context._reset_context()  # pylint: disable=protected-access

    def load_and_run_model(sess, input_values):
      """Load and run the SavedModel signature in the TF 1.x style."""
      model = saved_model_loader.load(sess, [tag_constants.SERVING], export_dir)
      signature = model.signature_def['train']
      inputs = list(signature.inputs.values())
      assert len(inputs) == 1, inputs
      outputs = list(signature.outputs.values())
      assert len(outputs) == 1, outputs
      input_tensor = sess.graph.get_tensor_by_name(inputs[0].name)
      output_tensor = sess.graph.get_tensor_by_name(outputs[0].name)
      for v in input_values:
        sess.run(output_tensor, feed_dict={input_tensor: v})

    with context.graph_mode(), ops.Graph().as_default():
      # Since writer shared_name is fixed, within a single session, all loads of
      # this SavedModel will refer to a single writer resource, so it will be
      # initialized only once and write to a single file.
      with self.session() as sess:
        load_and_run_model(sess, [1, 2])
        load_and_run_model(sess, [3, 4])
      post_restore_files = set(events_from_multifile_logdir(logdir))
      # New session will recreate the resource and write to a second file.
      with self.session() as sess:
        load_and_run_model(sess, [5, 6])
      files_to_events = events_from_multifile_logdir(logdir)
      post_restore2_files = set(files_to_events)

    self.assertLen(files_to_events, 3)
    def unwrap_singleton(iterable):
      self.assertLen(iterable, 1)
      return next(iter(iterable))
    restore_file = unwrap_singleton(post_restore_files - pre_save_files)
    restore2_file = unwrap_singleton(post_restore2_files - post_restore_files)
    restore_events = files_to_events[restore_file]
    restore2_events = files_to_events[restore2_file]
    self.assertLen(restore_events, 5)
    self.assertEqual(1, restore_events[1].step)
    self.assertEqual(2, restore_events[2].step)
    self.assertEqual(3, restore_events[3].step)
    self.assertEqual(4, restore_events[4].step)
    self.assertLen(restore2_events, 3)
    self.assertEqual(5, restore2_events[1].step)
    self.assertEqual(6, restore2_events[2].step)


class NoopWriterTest(test_util.TensorFlowTestCase):

  def testNoopWriter_doesNothing(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_noop_writer()
      writer.init()
      with writer.as_default():
        result = summary_ops.write('test', 1.0, step=0)
      writer.flush()
      writer.close()
      self.assertFalse(result)  # Should have found no active writer
    files = gfile.Glob(os.path.join(logdir, '*'))
    self.assertLen(files, 0)

  def testNoopWriter_asNestedContext_isTransparent(self):
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      noop_writer = summary_ops.create_noop_writer()
      with writer.as_default():
        result1 = summary_ops.write('first', 1.0, step=0)
        with noop_writer.as_default():
          result2 = summary_ops.write('second', 1.0, step=0)
        result3 = summary_ops.write('third', 1.0, step=0)
      # All ops should have written, including the one inside the no-op writer,
      # since it doesn't actively *disable* writing - it just behaves as if that
      # entire `with` block wasn't there at all.
      self.assertAllEqual([result1, result2, result3], [True, True, True])

  def testNoopWriter_setAsDefault(self):
    try:
      with context.eager_mode():
        writer = summary_ops.create_noop_writer()
        writer.set_as_default()
        result = summary_ops.write('test', 1.0, step=0)
        self.assertFalse(result)  # Should have found no active writer
    finally:
      # Ensure we clean up no matter how the test executes.
      summary_ops._summary_state.writer = None  # pylint: disable=protected-access


class SummaryOpsTest(test_util.TensorFlowTestCase):

  def tearDown(self):
    summary_ops.trace_off()

  def exec_summary_op(self, summary_op_fn):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_op_fn()
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata(self, *args, **kwargs):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_ops.run_metadata(*args, **kwargs)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def run_metadata_graphs(self, *args, **kwargs):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    with writer.as_default():
      summary_ops.run_metadata_graphs(*args, **kwargs)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  def create_run_metadata(self):
    step_stats = step_stats_pb2.StepStats(dev_stats=[
        step_stats_pb2.DeviceStepStats(
            device='cpu:0',
            node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')])
    ])
    return config_pb2.RunMetadata(
        function_graphs=[
            config_pb2.RunMetadata.FunctionGraphs(
                pre_optimization_graph=graph_pb2.GraphDef(
                    node=[node_def_pb2.NodeDef(name='foo')]))
        ],
        step_stats=step_stats)

  def run_trace(self, f, step=1):
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=False)
    with writer.as_default():
      f()
      summary_ops.trace_export(name='foo', step=step)
    writer.close()
    events = events_from_logdir(logdir)
    return events[1]

  @test_util.run_v2_only
  def testRunMetadata_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadata_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadata_wholeRunMetadata(self):
    expected_run_metadata = """
      step_stats {
        dev_stats {
          device: "cpu:0"
          node_stats {
            node_name: "hello"
          }
        }
      }
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadata_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo', skip_on_eager=False):
      event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadataGraph_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata_graph"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_runMetadataFragment(self):
    expected_run_metadata = """
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()

    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata_graphs(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    event = self.run_trace(f)

    first_val = event.summary.value[0]
    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])

    # Content of function_graphs is large and, for instance, device can change.
    self.assertTrue(hasattr(actual_run_metadata, 'function_graphs'))

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInFunction(self):

    @def_function.function
    def f():
      summary_ops.trace_on(graph=True, profiler=False)
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot enable trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_on(graph=True, profiler=False)
      self.assertRegex(
          str(mock_log.call_args), 'Must enable trace in eager mode.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceWithoutTrace(self):
    with six.assertRaisesRegex(self, ValueError,
                               'Must enable trace before export.'):
      summary_ops.trace_export(name='foo', step=1)

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInFunction(self):
    summary_ops.trace_on(graph=True, profiler=False)

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      summary_ops.trace_export(name='foo', step=1)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot export trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_export(name='foo', step=1)
      self.assertRegex(
          str(mock_log.call_args),
          'Can only export trace while executing eagerly.')

  @test_util.run_v2_only
  def testTrace_usesDefaultStep(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    try:
      summary_ops.set_step(42)
      event = self.run_trace(f, step=None)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testTrace_withProfiler(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(logdir)
    summary_ops.trace_on(graph=True, profiler=True)
    profiler_outdir = self.get_temp_dir()
    with writer.as_default():
      f()
      summary_ops.trace_export(
          name='foo', step=1, profiler_outdir=profiler_outdir)
    writer.close()

  @test_util.run_v2_only
  def testGraph_graph(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph)

    event = self.exec_summary_op(summary_op_fn)
    self.assertIsNotNone(event.graph_def)

  @test_util.run_v2_only
  def testGraph_graphDef(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    def summary_op_fn():
      summary_ops.graph(f.get_concrete_function().graph.as_graph_def())

    event = self.exec_summary_op(summary_op_fn)
    self.assertIsNotNone(event.graph_def)

  @test_util.run_v2_only
  def testGraph_invalidData(self):
    def summary_op_fn():
      summary_ops.graph('hello')

    with self.assertRaisesRegex(
        ValueError,
        r'\'graph_data\' is not tf.Graph or tf.compat.v1.GraphDef',
    ):
      self.exec_summary_op(summary_op_fn)

  @test_util.run_v2_only
  def testGraph_fromGraphMode(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    @def_function.function
    def g(graph):
      summary_ops.graph(graph)

    def summary_op_fn():
      graph_def = f.get_concrete_function().graph.as_graph_def(add_shapes=True)
      func_graph = constant_op.constant(graph_def.SerializeToString())
      g(func_graph)

    with self.assertRaisesRegex(
        ValueError,
        r'graph\(\) cannot be invoked inside a graph context.',
    ):
      self.exec_summary_op(summary_op_fn)


def events_from_file(filepath):
  """Returns all events in a single event file.

  Args:
    filepath: Path to the event file.

  Returns:
    A list of all tf.Event protos in the event file.
  """
  records = list(tf_record.tf_record_iterator(filepath))
  result = []
  for r in records:
    event = event_pb2.Event()
    event.ParseFromString(r)
    result.append(event)
  return result


def events_from_logdir(logdir):
  """Returns all events in the single event file in logdir.

  Args:
    logdir: The directory in which the single event file is sought.

  Returns:
    A list of all tf.Event protos from the single event file.

  Raises:
    AssertionError: If logdir does not contain exactly one file.
  """
  assert gfile.Exists(logdir)
  files = gfile.ListDirectory(logdir)
  assert len(files) == 1, 'Expected exactly one file in logdir, found: %s' % files
  return events_from_file(os.path.join(logdir, files[0]))


def events_from_multifile_logdir(logdir):
  """Returns map of filename to events for all `tfevents` files in the logdir.

  Args:
    logdir: The directory from which to load events.

  Returns:
    A dict mapping from relative filenames to lists of tf.Event protos.

  Raises:
    AssertionError: If logdir does not exist.
  """
  assert gfile.Exists(logdir)
  files = [file for file in gfile.ListDirectory(logdir) if 'tfevents' in file]
  return {file: events_from_file(os.path.join(logdir, file)) for file in files}


def to_numpy(summary_value):
  return tensor_util.MakeNdarray(summary_value.tensor)


if __name__ == '__main__':
  test.main()