• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for summary V1 tensor op."""
16
17import numpy as np
18
19from tensorflow.core.framework import summary_pb2
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.platform import test
25from tensorflow.python.summary import summary as summary_lib
26
27
class SummaryV1TensorOpTest(test.TestCase):
  """Tests for the V1 `tensor_summary` op.

  Each test builds one or more `tensor_summary` ops, evaluates them to
  serialized `Summary` protos, and checks the resulting tag, the round-tripped
  tensor value, or the attached `SummaryMetadata`.
  """

  def _SummarySingleValue(self, s):
    """Parses a serialized `Summary` proto and returns its only value.

    Args:
      s: Serialized `summary_pb2.Summary` bytes, as produced by evaluating a
        summary op.

    Returns:
      The single `Summary.Value` contained in the parsed proto.

    Raises:
      AssertionError: If the summary does not hold exactly one value.
    """
    summ = summary_pb2.Summary()
    summ.ParseFromString(s)
    self.assertEqual(len(summ.value), 1)
    return summ.value[0]

  def _AssertNumpyEq(self, actual, expected):
    """Asserts that `actual` and `expected` have equal shape and elements.

    `np.array_equal` does not broadcast, so a scalar will not match a larger
    array. Both values are included in the failure message so that a mismatch
    can be diagnosed without re-running under a debugger.
    """
    self.assertTrue(
        np.array_equal(actual, expected),
        msg="%r != %r" % (actual, expected))

  def testTags(self):
    """Summary tags are prefixed by the enclosing name scopes."""
    with self.cached_session():
      c = constant_op.constant(1)
      s1 = summary_lib.tensor_summary("s1", c)
      with ops.name_scope("foo", skip_on_eager=False):
        s2 = summary_lib.tensor_summary("s2", c)
        with ops.name_scope("zod", skip_on_eager=False):
          s3 = summary_lib.tensor_summary("s3", c)
          s4 = summary_lib.tensor_summary("TensorSummary", c)
      summ1, summ2, summ3, summ4 = self.evaluate([s1, s2, s3, s4])

    v1 = self._SummarySingleValue(summ1)
    self.assertEqual(v1.tag, "s1")

    v2 = self._SummarySingleValue(summ2)
    self.assertEqual(v2.tag, "foo/s2")

    v3 = self._SummarySingleValue(summ3)
    self.assertEqual(v3.tag, "foo/zod/s3")

    v4 = self._SummarySingleValue(summ4)
    self.assertEqual(v4.tag, "foo/zod/TensorSummary")

  def testScalarSummary(self):
    """A scalar float tensor round-trips through the summary proto."""
    with self.cached_session():
      const = constant_op.constant(10.0)
      summ = summary_lib.tensor_summary("foo", const)
      result = self.evaluate(summ)

    value = self._SummarySingleValue(result)
    n = tensor_util.MakeNdarray(value.tensor)
    self._AssertNumpyEq(n, 10)

  def testStringSummary(self):
    """A scalar bytes tensor round-trips through the summary proto."""
    s = b"foobar"
    with self.cached_session():
      const = constant_op.constant(s)
      summ = summary_lib.tensor_summary("foo", const)
      result = self.evaluate(summ)

    value = self._SummarySingleValue(result)
    n = tensor_util.MakeNdarray(value.tensor)
    self._AssertNumpyEq(n, s)

  def testManyScalarSummary(self):
    """A rank-3 float tensor round-trips through the summary proto."""
    with self.cached_session():
      const = array_ops.ones([5, 5, 5])
      summ = summary_lib.tensor_summary("foo", const)
      result = self.evaluate(summ)
    value = self._SummarySingleValue(result)
    n = tensor_util.MakeNdarray(value.tensor)
    self._AssertNumpyEq(n, np.ones([5, 5, 5]))

  def testManyStringSummary(self):
    """A 2x2 bytes tensor round-trips through the summary proto."""
    strings = [[b"foo bar", b"baz"], [b"zoink", b"zod"]]
    with self.cached_session():
      const = constant_op.constant(strings)
      summ = summary_lib.tensor_summary("foo", const)
      result = self.evaluate(summ)
    value = self._SummarySingleValue(result)
    n = tensor_util.MakeNdarray(value.tensor)
    self._AssertNumpyEq(n, strings)

  def testManyBools(self):
    """A 1-D bool tensor round-trips through the summary proto."""
    bools = [True, True, True, False, False, False]
    with self.cached_session():
      const = constant_op.constant(bools)
      summ = summary_lib.tensor_summary("foo", const)
      result = self.evaluate(summ)

    value = self._SummarySingleValue(result)
    n = tensor_util.MakeNdarray(value.tensor)
    self._AssertNumpyEq(n, bools)

  def testSummaryDescriptionAndDisplayName(self):
    """Display name and description come from args, metadata, or both."""
    with self.cached_session():

      def get_metadata(summary_op):
        """Evaluates `summary_op` and returns its single value's metadata."""
        # Reuse the shared parser so this test also checks that exactly one
        # value is emitted, like every other test in this class.
        return self._SummarySingleValue(self.evaluate(summary_op)).metadata

      const = constant_op.constant(1)
      # Default case; no description or display name.
      simple_summary = summary_lib.tensor_summary("simple", const)

      descr = get_metadata(simple_summary)
      self.assertEqual(descr.display_name, "")
      self.assertEqual(descr.summary_description, "")

      # Values are provided via function args.
      with_values = summary_lib.tensor_summary(
          "simple",
          const,
          display_name="my name",
          summary_description="my description")

      descr = get_metadata(with_values)
      self.assertEqual(descr.display_name, "my name")
      self.assertEqual(descr.summary_description, "my description")

      # Values are provided via the SummaryMetadata arg.
      metadata = summary_pb2.SummaryMetadata()
      metadata.display_name = "my name"
      metadata.summary_description = "my description"

      with_metadata = summary_lib.tensor_summary(
          "simple", const, summary_metadata=metadata)
      descr = get_metadata(with_metadata)
      self.assertEqual(descr.display_name, "my name")
      self.assertEqual(descr.summary_description, "my description")

      # If both SummaryMetadata and explicit args are provided, the args win.
      overwrite = summary_lib.tensor_summary(
          "simple",
          const,
          summary_metadata=metadata,
          display_name="overwritten",
          summary_description="overwritten")
      descr = get_metadata(overwrite)
      self.assertEqual(descr.display_name, "overwritten")
      self.assertEqual(descr.summary_description, "overwritten")
162
163
if __name__ == "__main__":
  # Script entry point: hand off to TensorFlow's test main, imported above
  # from tensorflow.python.platform.
  test.main()
166