# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Test utilities for validating summaries and reading TF event files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os

import numpy as np

from tensorflow.core.framework import summary_pb2
from tensorflow.python.training import summary_io


def assert_summary(expected_tags, expected_simple_values, summary_proto):
  """Asserts summary contains the specified tags and values.

  Each value in `summary_proto` whose tag is a key of `expected_simple_values`
  has its `simple_value` compared to the expected number to 2 decimal places.
  Afterwards, the set of tags present in the summary must equal
  `expected_tags` exactly.

  Args:
    expected_tags: All tags in summary.
    expected_simple_values: Simple values for some tags, as a mapping from tag
      to expected number. Tags listed here but absent from the summary are not
      value-checked; they are only caught if also listed in `expected_tags`.
    summary_proto: Summary to validate.

  Raises:
    AssertionError: if a simple value differs from its expectation (raised by
      `np.testing.assert_almost_equal`, with the tag as the message).
    ValueError: if the set of tags in the summary differs from `expected_tags`.
  """
  actual_tags = set()
  for value in summary_proto.value:
    actual_tags.add(value.tag)
    if value.tag in expected_simple_values:
      expected = expected_simple_values[value.tag]
      actual = value.simple_value
      np.testing.assert_almost_equal(
          actual, expected, decimal=2, err_msg=value.tag)
  expected_tags = set(expected_tags)
  if expected_tags != actual_tags:
    raise ValueError('Expected tags %s, got %s.' % (expected_tags, actual_tags))


def to_summary_proto(summary_str):
  """Deserializes a `Summary` proto from its serialized form.

  Args:
    summary_str: Serialized `Summary` proto.
  Returns:
    summary_pb2.Summary.
  Raises:
    google.protobuf.message.DecodeError: if `summary_str` cannot be parsed
      (raised by `ParseFromString`).
  """
  summary = summary_pb2.Summary()
  summary.ParseFromString(summary_str)
  return summary


# TODO(ptucker): Move to a non-test package?
def latest_event_file(base_dir):
  """Find latest event file in `base_dir`.

  "Latest" means the lexicographically greatest path matching `events.*`;
  this assumes event file names sort in creation order.

  Args:
    base_dir: Base directory in which TF event files are stored.
  Returns:
    File path, or `None` if none exists.
  """
  file_paths = glob.glob(os.path.join(base_dir, 'events.*'))
  # max() selects the same path as sorted(...)[-1] without a full sort.
  return max(file_paths) if file_paths else None


def latest_events(base_dir):
  """Parse events from latest event file in base_dir.

  Args:
    base_dir: Base directory in which TF event files are stored.
  Returns:
    Iterable of event protos read from the file chosen by
    `latest_event_file`, or an empty list if `base_dir` contains no event
    files (no error is raised in that case).
  """
  file_path = latest_event_file(base_dir)
  return summary_io.summary_iterator(file_path) if file_path else []


def latest_summaries(base_dir):
  """Parse summary events from latest event file in base_dir.

  Args:
    base_dir: Base directory in which TF event files are stored.
  Returns:
    List of event protos that have a `summary` field set; empty if `base_dir`
    contains no event files.
  """
  return [e for e in latest_events(base_dir) if e.HasField('summary')]


def simple_values_from_events(events, tags):
  """Parse summaries from events with simple_value.

  For each requested tag, the value from the event with the highest step is
  kept; when several events share that step, the first one seen wins.

  Args:
    events: Iterable of tensorflow.Event protos.
    tags: Iterable of string event tags corresponding to simple_value
      summaries.
  Returns:
    dict of tag:value. Requested tags that never appear in `events` are
    omitted.
  Raises:
    ValueError: if a summary with a specified tag does not contain simple_value.
  """
  # Materialize once: constant-time membership tests, and a generator passed
  # as `tags` is not consumed by the first `in` check.
  wanted_tags = frozenset(tags)
  step_by_tag = {}
  value_by_tag = {}
  for e in events:
    if not e.HasField('summary'):
      continue
    for v in e.summary.value:
      tag = v.tag
      if tag not in wanted_tags:
        continue
      if not v.HasField('simple_value'):
        raise ValueError('Summary for %s is not a simple_value.' % tag)
      # The events are mostly sorted in step order, but we explicitly check
      # just in case.
      if tag not in step_by_tag or e.step > step_by_tag[tag]:
        step_by_tag[tag] = e.step
        value_by_tag[tag] = v.simple_value
  return value_by_tag