• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Test utilities."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import glob
21import os
22import numpy as np
23from tensorflow.core.framework import summary_pb2
24from tensorflow.python.training import summary_io
25
26
def assert_summary(expected_tags, expected_simple_values, summary_proto):
  """Asserts summary contains the specified tags and values.

  Simple values are checked first, while iterating over `summary_proto.value`;
  the tag-set comparison happens afterwards. An entry in
  `expected_simple_values` whose tag does not appear in `summary_proto` is not
  value-checked here (it is only caught if it is also in `expected_tags`).

  Args:
    expected_tags: Iterable of all tags in the summary; the set of tags found
      in `summary_proto` must equal it exactly.
    expected_simple_values: dict mapping tag to expected `simple_value` for
      some tags, compared with `np.testing.assert_almost_equal(decimal=2)`.
    summary_proto: Summary to validate.

  Raises:
    AssertionError: if a checked `simple_value` does not match its expected
      value (raised by `np.testing.assert_almost_equal`).
    ValueError: if the set of tags in `summary_proto` differs from
      `expected_tags`.
  """
  actual_tags = set()
  for value in summary_proto.value:
    actual_tags.add(value.tag)
    if value.tag in expected_simple_values:
      # `err_msg` puts the offending tag into the assertion failure message.
      np.testing.assert_almost_equal(
          value.simple_value, expected_simple_values[value.tag], decimal=2,
          err_msg=value.tag)
  expected_tags = set(expected_tags)
  if expected_tags != actual_tags:
    raise ValueError('Expected tags %s, got %s.' % (expected_tags, actual_tags))
49
50
def to_summary_proto(summary_str):
  """Parse a serialized `Summary` proto.

  Args:
    summary_str: Serialized summary (bytes of a `summary_pb2.Summary`).
  Returns:
    summary_pb2.Summary parsed from `summary_str`.
  Raises:
    google.protobuf.message.DecodeError: raised by `ParseFromString` if
      `summary_str` is not a valid serialized Summary.
  """
  summary = summary_pb2.Summary()
  summary.ParseFromString(summary_str)
  return summary
64
65
66# TODO(ptucker): Move to a non-test package?
def latest_event_file(base_dir):
  """Find latest event file in `base_dir`.

  Event files are the entries of `base_dir` whose names start with `events.`;
  the "latest" one is the entry with the lexicographically greatest name.

  Args:
    base_dir: Base directory in which TF event files are stored.
  Returns:
    File path, or `None` if none exists (including when `base_dir` does not
    exist or is not a directory).
  """
  # List the directory instead of globbing it: glob would interpret characters
  # such as `[`, `*` or `?` inside `base_dir` as wildcards and silently miss
  # the files. `os.listdir('')` fails, so an empty base_dir means the cwd,
  # matching what glob did for a relative 'events.*' pattern.
  try:
    names = os.listdir(base_dir or os.curdir)
  except OSError:
    return None
  event_names = [name for name in names if name.startswith('events.')]
  if not event_names:
    return None
  return os.path.join(base_dir, max(event_names))
77
78
def latest_events(base_dir):
  """Parse events from latest event file in base_dir.

  Args:
    base_dir: Base directory in which TF event files are stored.
  Returns:
    Iterable of event protos: the result of `summary_io.summary_iterator` on
    the file chosen by `latest_event_file`, or an empty list if no event file
    exists under `base_dir` (no exception is raised in that case).
  """
  file_path = latest_event_file(base_dir)
  return summary_io.summary_iterator(file_path) if file_path else []
91
92
def latest_summaries(base_dir):
  """Parse summary events from latest event file in base_dir.

  Args:
    base_dir: Base directory in which TF event files are stored.
  Returns:
    List of event protos from the latest event file that have their `summary`
    field set; an empty list if no event file exists under `base_dir` (no
    exception is raised in that case).
  """
  return [e for e in latest_events(base_dir) if e.HasField('summary')]
104
105
def simple_values_from_events(events, tags):
  """Parse summaries from events with simple_value.

  For each requested tag, the reported value comes from the event with the
  highest step; if several events share that step, the earliest one wins.

  Args:
    events: Iterable of tensorflow.Event protos.
    tags: Iterable of string event tags corresponding to simple_value
      summaries. It is read exactly once, so a generator is fine.
  Returns:
    dict of tag:value. Requested tags that never appear in `events` are
    omitted.
  Raises:
    ValueError: if a summary with a specified tag does not contain simple_value.
  """
  # Materialize `tags` once: gives O(1) membership tests, and keeps a one-shot
  # iterator from being exhausted by the first few `in` checks.
  wanted_tags = frozenset(tags)
  step_by_tag = {}
  value_by_tag = {}
  for e in events:
    if not e.HasField('summary'):
      continue
    for v in e.summary.value:
      tag = v.tag
      if tag not in wanted_tags:
        continue
      if not v.HasField('simple_value'):
        raise ValueError('Summary for %s is not a simple_value.' % tag)
      # The events are mostly sorted in step order, but we explicitly check
      # just in case. Strict `>` keeps the first value seen for a given step.
      if tag not in step_by_tag or e.step > step_by_tag[tag]:
        step_by_tag[tag] = e.step
        value_by_tag[tag] = v.simple_value
  return value_by_tag
132