1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Fake summary writer for unit tests.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.core.framework import summary_pb2 21from tensorflow.python.framework import test_util 22from tensorflow.python.summary.writer import writer 23from tensorflow.python.summary.writer import writer_cache 24 25 26# TODO(ptucker): Replace with mock framework. 
class FakeSummaryWriter(object):
  """Fake `FileWriter` that records everything written to it in memory.

  Nothing is written to disk: each `add_*` call stores its argument in an
  in-memory collection, and `flush`, `reopen` and `close` are no-ops. Call
  `install()` to patch this class in place of `writer.FileWriter` (and the
  `FileWriter` name used by `writer_cache`), `uninstall()` to restore the
  original class, and `assert_summaries()` to verify what was recorded.
  """

  # The real FileWriter class saved by `install()`; None when not installed.
  _replaced_summary_writer = None

  @classmethod
  def install(cls):
    """Patches `writer.FileWriter` and `writer_cache.FileWriter` with the fake.

    Raises:
      ValueError: If the fake is already installed.
    """
    if cls._replaced_summary_writer is not None:
      raise ValueError('FakeSummaryWriter already installed.')
    cls._replaced_summary_writer = writer.FileWriter
    writer.FileWriter = FakeSummaryWriter
    writer_cache.FileWriter = FakeSummaryWriter

  @classmethod
  def uninstall(cls):
    """Restores the `FileWriter` class that `install()` replaced.

    Raises:
      ValueError: If the fake is not currently installed.
    """
    if cls._replaced_summary_writer is None:
      raise ValueError('FakeSummaryWriter not installed.')
    writer.FileWriter = cls._replaced_summary_writer
    writer_cache.FileWriter = cls._replaced_summary_writer
    cls._replaced_summary_writer = None

  def __init__(self, logdir, graph=None):
    """Creates an empty fake writer.

    Args:
      logdir: Log directory the writer is created for. It is only stored, so
        that `assert_summaries` can compare it.
      graph: Optional graph passed at construction. `assert_summaries`
        compares it by identity.
    """
    self._logdir = logdir
    self._graph = graph
    # global step -> list of summaries added at that step, in insertion order.
    self._summaries = {}
    self._added_graphs = []
    self._added_meta_graphs = []
    self._added_session_logs = []
    # tag -> run metadata; a later call with the same tag overwrites it.
    self._added_run_metadata = {}

  @property
  def summaries(self):
    """Dict mapping each global step to the list of summaries added at it."""
    return self._summaries

  @staticmethod
  def _validate_global_step(global_step):
    """Raises ValueError if `global_step` is given and negative."""
    if (global_step is not None) and (global_step < 0):
      raise ValueError('Invalid global_step %s.' % global_step)

  def assert_summaries(self,
                       test_case,
                       expected_logdir=None,
                       expected_graph=None,
                       expected_summaries=None,
                       expected_added_graphs=None,
                       expected_added_meta_graphs=None,
                       expected_session_logs=None,
                       expected_run_metadata=None):
    """Asserts expected items have been added to this summary writer.

    Each `expected_*` argument left as None is not checked.

    Args:
      test_case: Test case whose `assert*` methods perform the checks.
      expected_logdir: Expected `logdir` passed to the constructor.
      expected_graph: Graph expected to be the very object passed to the
        constructor (compared by identity).
      expected_summaries: Dict mapping a global step to the `{tag:
        simple_value}` dict expected at that step. Steps that were written
        but are not listed here are not checked. The `global_step/sec` tag
        is always ignored.
      expected_added_graphs: Expected list of graphs passed to `add_graph`.
      expected_added_meta_graphs: Expected list of meta graphs passed to
        `add_meta_graph`. They are compared pairwise with
        `test_util.assert_meta_graph_protos_equal`.
      expected_session_logs: Expected list of session logs passed to
        `add_session_log`.
      expected_run_metadata: Expected dict mapping each tag to the run
        metadata last passed to `add_run_metadata` with that tag.
    """
    if expected_logdir is not None:
      test_case.assertEqual(expected_logdir, self._logdir)
    if expected_graph is not None:
      test_case.assertIs(expected_graph, self._graph)
    for step, expected_values in (expected_summaries or {}).items():
      test_case.assertTrue(
          step in self._summaries,
          msg='Missing step %s from %s.' % (step, list(self._summaries.keys())))
      actual_simple_values = {}
      for step_summary in self._summaries[step]:
        for v in step_summary.value:
          # Ignore global_step/sec since it's written by Supervisor in a
          # separate thread, so it's non-deterministic how many get written.
          if v.tag != 'global_step/sec':
            actual_simple_values[v.tag] = v.simple_value
      test_case.assertEqual(expected_values, actual_simple_values)
    if expected_added_graphs is not None:
      test_case.assertEqual(expected_added_graphs, self._added_graphs)
    if expected_added_meta_graphs is not None:
      test_case.assertEqual(len(expected_added_meta_graphs),
                            len(self._added_meta_graphs))
      for expected, actual in zip(expected_added_meta_graphs,
                                  self._added_meta_graphs):
        test_util.assert_meta_graph_protos_equal(test_case, expected, actual)
    if expected_session_logs is not None:
      test_case.assertEqual(expected_session_logs, self._added_session_logs)
    if expected_run_metadata is not None:
      test_case.assertEqual(expected_run_metadata, self._added_run_metadata)

  def add_summary(self, summ, current_global_step):
    """Records a summary under `current_global_step`.

    Args:
      summ: A `summary_pb2.Summary` proto or its serialized bytes. Bytes are
        parsed into a `Summary` before being stored.
      current_global_step: Step key the summary is grouped under.
    """
    if isinstance(summ, bytes):
      summary_proto = summary_pb2.Summary()
      summary_proto.ParseFromString(summ)
      summ = summary_proto
    self._summaries.setdefault(current_global_step, []).append(summ)

  def add_graph(self, graph, global_step=None, graph_def=None):
    """Records `graph`.

    `global_step` is only range-checked, not recorded, because its value is
    non-deterministic.

    Raises:
      ValueError: If `global_step` is negative or `graph_def` is given.
    """
    self._validate_global_step(global_step)
    if graph_def is not None:
      raise ValueError('Unexpected graph_def %s.' % graph_def)
    self._added_graphs.append(graph)

  def add_meta_graph(self, meta_graph_def, global_step=None):
    """Records `meta_graph_def`.

    `global_step` is range-checked but not recorded.

    Raises:
      ValueError: If `global_step` is negative.
    """
    self._validate_global_step(global_step)
    self._added_meta_graphs.append(meta_graph_def)

  def add_session_log(self, session_log, global_step=None):
    """Records `session_log`.

    `global_step` is ignored entirely because its value is non-deterministic.
    """
    # pylint: disable=unused-argument
    self._added_session_logs.append(session_log)

  def add_run_metadata(self, run_metadata, tag, global_step=None):
    """Records `run_metadata` under `tag`, replacing any earlier entry.

    Raises:
      ValueError: If `global_step` is negative.
    """
    self._validate_global_step(global_step)
    self._added_run_metadata[tag] = run_metadata

  def flush(self):
    """No-op; nothing is buffered."""
    pass

  def reopen(self):
    """No-op; there is no file to reopen."""
    pass

  def close(self):
    """No-op; there is no file to close."""
    pass