1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Fake summary writer for unit tests.""" 16from tensorflow.core.framework import summary_pb2 17from tensorflow.python.framework import test_util 18from tensorflow.python.summary.writer import writer 19from tensorflow.python.summary.writer import writer_cache 20 21 22# TODO(ptucker): Replace with mock framework. 
class FakeSummaryWriter(object):
  """Fake summary writer that records everything added to it in memory.

  `install()` swaps this class in for `writer.FileWriter` and
  `writer_cache.FileWriter`; `uninstall()` restores the replaced class.
  Recorded items can be checked with `assert_summaries()`.
  """

  # The FileWriter class replaced by `install()`; None while not installed.
  _replaced_summary_writer = None

  @classmethod
  def install(cls):
    """Replace `FileWriter` with this class.

    Raises:
      ValueError: if a fake summary writer is already installed.
    """
    if cls._replaced_summary_writer is not None:
      raise ValueError('FakeSummaryWriter already installed.')
    cls._replaced_summary_writer = writer.FileWriter
    # Install `cls` (not the hard-coded base class) so subclasses work too.
    writer.FileWriter = cls
    writer_cache.FileWriter = cls

  @classmethod
  def uninstall(cls):
    """Restore the `FileWriter` class replaced by `install()`.

    Raises:
      ValueError: if no fake summary writer is installed.
    """
    if cls._replaced_summary_writer is None:
      raise ValueError('FakeSummaryWriter not installed.')
    writer.FileWriter = cls._replaced_summary_writer
    writer_cache.FileWriter = cls._replaced_summary_writer
    cls._replaced_summary_writer = None

  def __init__(self, logdir, graph=None):
    self._logdir = logdir
    self._graph = graph
    # Maps global step -> list of summaries added at that step.
    self._summaries = {}
    self._added_graphs = []
    self._added_meta_graphs = []
    self._added_session_logs = []
    # Maps tag -> run metadata; a later call with the same tag overwrites.
    self._added_run_metadata = {}

  @property
  def summaries(self):
    """Dict mapping global step to the list of summaries added at that step."""
    return self._summaries

  def assert_summaries(self,
                       test_case,
                       expected_logdir=None,
                       expected_graph=None,
                       expected_summaries=None,
                       expected_added_graphs=None,
                       expected_added_meta_graphs=None,
                       expected_session_logs=None,
                       expected_run_metadata=None):
    """Assert expected items have been added to summary writer.

    Every `expected_*` argument left as None is not checked.

    Args:
      test_case: object providing `assertEqual`, `assertIs` and `assertIn`
        (typically a `unittest.TestCase`).
      expected_logdir: expected `logdir` passed to the constructor.
      expected_graph: expected `graph` passed to the constructor, compared by
        identity.
      expected_summaries: dict mapping step -> {tag: simple_value}. Only the
        listed steps are checked. For each listed step, the tag/simple_value
        map built from all summaries recorded at that step must match exactly.
        The 'global_step/sec' tag is excluded.
      expected_added_graphs: list compared to the graphs passed to
        `add_graph`.
      expected_added_meta_graphs: list compared element-wise to the meta
        graphs passed to `add_meta_graph`.
      expected_session_logs: list compared to the session logs passed to
        `add_session_log`.
      expected_run_metadata: dict mapping tag -> run metadata, compared to
        what was passed to `add_run_metadata`.
    """
    if expected_logdir is not None:
      test_case.assertEqual(expected_logdir, self._logdir)
    if expected_graph is not None:
      test_case.assertIs(expected_graph, self._graph)
    for step, expected_values in (expected_summaries or {}).items():
      test_case.assertIn(
          step, self._summaries,
          msg='Missing step %s from %s.' % (step, list(self._summaries)))
      actual_simple_values = {}
      for step_summary in self._summaries[step]:
        for v in step_summary.value:
          # Ignore global_step/sec since it's written by Supervisor in a
          # separate thread, so it's non-deterministic how many get written.
          if v.tag != 'global_step/sec':
            actual_simple_values[v.tag] = v.simple_value
      test_case.assertEqual(expected_values, actual_simple_values)
    if expected_added_graphs is not None:
      test_case.assertEqual(expected_added_graphs, self._added_graphs)
    if expected_added_meta_graphs is not None:
      test_case.assertEqual(len(expected_added_meta_graphs),
                            len(self._added_meta_graphs))
      for expected, actual in zip(expected_added_meta_graphs,
                                  self._added_meta_graphs):
        test_util.assert_meta_graph_protos_equal(test_case, expected, actual)
    if expected_session_logs is not None:
      test_case.assertEqual(expected_session_logs, self._added_session_logs)
    if expected_run_metadata is not None:
      test_case.assertEqual(expected_run_metadata, self._added_run_metadata)

  @staticmethod
  def _check_global_step(global_step):
    """Raise ValueError if `global_step` is given and negative."""
    if global_step is not None and global_step < 0:
      raise ValueError('Invalid global_step %s.' % global_step)

  def add_summary(self, summ, current_global_step):
    """Record `summ` under `current_global_step`.

    Args:
      summ: a `Summary` proto, or serialized bytes that are parsed into one.
      current_global_step: step key under which the summary is recorded.
    """
    if isinstance(summ, bytes):
      summary_proto = summary_pb2.Summary()
      summary_proto.ParseFromString(summ)
      summ = summary_proto
    self._summaries.setdefault(current_global_step, []).append(summ)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_graph(self, graph, global_step=None, graph_def=None):
    """Record `graph`.

    Raises:
      ValueError: if `global_step` is negative or `graph_def` is given.
    """
    self._check_global_step(global_step)
    if graph_def is not None:
      raise ValueError('Unexpected graph_def %s.' % graph_def)
    self._added_graphs.append(graph)

  def add_meta_graph(self, meta_graph_def, global_step=None):
    """Record `meta_graph_def`.

    Raises:
      ValueError: if `global_step` is negative.
    """
    self._check_global_step(global_step)
    self._added_meta_graphs.append(meta_graph_def)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_session_log(self, session_log, global_step=None):
    """Record `session_log`; `global_step` is deliberately not validated."""
    # pylint: disable=unused-argument
    self._added_session_logs.append(session_log)

  def add_run_metadata(self, run_metadata, tag, global_step=None):
    """Record `run_metadata` under `tag`, overwriting any earlier entry.

    Raises:
      ValueError: if `global_step` is negative.
    """
    self._check_global_step(global_step)
    self._added_run_metadata[tag] = run_metadata

  def flush(self):
    """No-op."""

  def reopen(self):
    """No-op."""

  def close(self):
    """No-op."""