• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the actual serialized proto output of the V1 tf.summary ops.
16
17The tensor, audio, and image ops have dedicated tests in adjacent files. The
18overall tf.summary API surface also has its own tests in summary_test.py that
19check calling the API methods but not the exact serialized proto output.
20"""
21
22from tensorflow.core.framework import summary_pb2
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import logging_ops
27from tensorflow.python.platform import test
28from tensorflow.python.summary import summary
29
30
class SummaryV1OpsTest(test.TestCase):
  """Checks the serialized `Summary` protos produced by the V1 summary ops."""

  def _AsSummary(self, s):
    """Parses the serialized bytes `s` into a `summary_pb2.Summary` message.

    Args:
      s: Serialized `Summary` proto, as returned by evaluating a summary op.

    Returns:
      The parsed `summary_pb2.Summary` message.
    """
    summ = summary_pb2.Summary()
    summ.ParseFromString(s)
    return summ

  def testScalarSummary(self):
    """An explicitly named scalar summary emits one value per tag."""
    # The session is only needed as the surrounding context; nothing in the
    # test talks to it directly, so it is not bound to a local name.
    with self.cached_session():
      const = constant_op.constant([10.0, 20.0])
      summ = logging_ops.scalar_summary(["c1", "c2"], const, name="mysumm")
      value = self.evaluate(summ)
    # The summary op's output has an empty (scalar) static shape.
    self.assertEqual([], summ.get_shape())
    self.assertProtoEquals("""
      value { tag: "c1" simple_value: 10.0 }
      value { tag: "c2" simple_value: 20.0 }
      """, self._AsSummary(value))

  def testScalarSummaryDefaultName(self):
    """Same as testScalarSummary, but without passing `name` to the op."""
    with self.cached_session():
      const = constant_op.constant([10.0, 20.0])
      summ = logging_ops.scalar_summary(["c1", "c2"], const)
      value = self.evaluate(summ)
    self.assertEqual([], summ.get_shape())
    self.assertProtoEquals("""
      value { tag: "c1" simple_value: 10.0 }
      value { tag: "c2" simple_value: 20.0 }
      """, self._AsSummary(value))

  @test_util.run_deprecated_v1
  def testMergeSummary(self):
    """Merging a histogram and a scalar summary keeps both values in order."""
    with self.cached_session():
      const = constant_op.constant(10.0)
      summ1 = summary.histogram("h", const)
      summ2 = logging_ops.scalar_summary("c", const)
      merge = summary.merge([summ1, summ2])
      value = self.evaluate(merge)
    self.assertEqual([], merge.get_shape())
    # Expected output: the histogram value first (a single sample of 10.0,
    # counted in the middle of the three listed buckets), then the scalar.
    self.assertProtoEquals("""
      value {
        tag: "h"
        histo {
          min: 10.0
          max: 10.0
          num: 1.0
          sum: 10.0
          sum_squares: 100.0
          bucket_limit: 9.93809490288
          bucket_limit: 10.9319043932
          bucket_limit: 1.7976931348623157e+308
          bucket: 0.0
          bucket: 1.0
          bucket: 0.0
        }
      }
      value { tag: "c" simple_value: 10.0 }
    """, self._AsSummary(value))

  def testMergeAllSummaries(self):
    """`merge_all` merges only the summaries of the requested collection."""
    # A fresh graph keeps the summary collections free of ops created by
    # the other tests.
    with ops.Graph().as_default():
      const = constant_op.constant(10.0)
      summ1 = summary.histogram("h", const)
      summ2 = summary.scalar("o", const, collections=["foo_key"])
      summ3 = summary.scalar("c", const)

      # Called without a key: expects the two summaries that were not given
      # an explicit `collections` argument, in creation order.
      merge = summary.merge_all()
      self.assertEqual("MergeSummary", merge.op.type)
      self.assertEqual(2, len(merge.op.inputs))
      self.assertEqual(summ1, merge.op.inputs[0])
      self.assertEqual(summ3, merge.op.inputs[1])

      # Called with "foo_key": expects only the summary registered there.
      merge = summary.merge_all("foo_key")
      self.assertEqual("MergeSummary", merge.op.type)
      self.assertEqual(1, len(merge.op.inputs))
      self.assertEqual(summ2, merge.op.inputs[0])

      # A key with no registered summaries yields None rather than an op.
      self.assertIsNone(summary.merge_all("bar_key"))
105
106
if __name__ == "__main__":
  # Executed as a script: delegate to the test runner imported above.
  test.main()
109