• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the API surface of the V1 tf.summary ops.

These tests don't check the actual serialized proto summary value for the
more complex summaries (e.g. audio, image).  Those tests live separately in
tensorflow/python/kernel_tests/summary_v1_*.py.
"""
21
22
23from tensorflow.core.framework import summary_pb2
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import meta_graph
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import test
33from tensorflow.python.summary import summary as summary_lib
34
35
class SummaryTest(test.TestCase):
  """Checks op names, tags, and values produced by the V1 tf.summary ops."""

  @staticmethod
  def _parse_summary(summary_str):
    """Deserializes the output of a summary op.

    Args:
      summary_str: Serialized `Summary` bytes, as returned by running a
        summary op.

    Returns:
      The `summary_pb2.Summary` proto parsed from `summary_str`.
    """
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    return summary

  @test_util.run_deprecated_v1
  def testScalarSummary(self):
    with self.cached_session() as s:
      i = constant_op.constant(3)
      with ops.name_scope('outer'):
        im = summary_lib.scalar('inner', i)
      summary_str = s.run(im)
    values = self._parse_summary(summary_str).value
    self.assertLen(values, 1)
    self.assertEqual(values[0].tag, 'outer/inner')
    self.assertEqual(values[0].simple_value, 3.0)

  @test_util.run_deprecated_v1
  def testScalarSummaryWithFamily(self):
    with self.cached_session() as s:
      i = constant_op.constant(7)
      with ops.name_scope('outer'):
        im1 = summary_lib.scalar('inner', i, family='family')
        self.assertEqual(im1.op.name, 'outer/family/inner')
        # Reusing the same summary name yields a uniquified op name.
        im2 = summary_lib.scalar('inner', i, family='family')
        self.assertEqual(im2.op.name, 'outer/family/inner_1')
      sm1, sm2 = s.run([im1, im2])

    # The tag is the family followed by the full (scoped) op name.
    values = self._parse_summary(sm1).value
    self.assertLen(values, 1)
    self.assertEqual(values[0].tag, 'family/outer/family/inner')
    self.assertEqual(values[0].simple_value, 7.0)

    values = self._parse_summary(sm2).value
    self.assertLen(values, 1)
    self.assertEqual(values[0].tag, 'family/outer/family/inner_1')
    self.assertEqual(values[0].simple_value, 7.0)

  @test_util.run_deprecated_v1
  def testSummarizingVariable(self):
    with self.cached_session() as s:
      c = constant_op.constant(42.0)
      v = variables.Variable(c)
      ss = summary_lib.scalar('summary', v)
      init = variables.global_variables_initializer()
      s.run(init)
      summ_str = s.run(ss)
    values = self._parse_summary(summ_str).value
    self.assertLen(values, 1)
    self.assertEqual(values[0].tag, 'summary')
    self.assertEqual(values[0].simple_value, 42.0)

  @test_util.run_deprecated_v1
  def testImageSummary(self):
    with self.cached_session() as s:
      i = array_ops.ones((5, 4, 4, 3))
      with ops.name_scope('outer'):
        im = summary_lib.image('inner', i, max_outputs=3)
      summary_str = s.run(im)
    values = self._parse_summary(summary_str).value
    # max_outputs=3 limits the batch of 5 images to three summary values.
    self.assertLen(values, 3)
    tags = sorted(v.tag for v in values)
    expected = sorted('outer/inner/image/{}'.format(k) for k in range(3))
    self.assertEqual(tags, expected)

  @test_util.run_deprecated_v1
  def testImageSummaryWithFamily(self):
    with self.cached_session() as s:
      i = array_ops.ones((5, 2, 3, 1))
      with ops.name_scope('outer'):
        im = summary_lib.image('inner', i, max_outputs=3, family='family')
        self.assertEqual(im.op.name, 'outer/family/inner')
      summary_str = s.run(im)
    values = self._parse_summary(summary_str).value
    self.assertLen(values, 3)
    tags = sorted(v.tag for v in values)
    expected = sorted(
        'family/outer/family/inner/image/{}'.format(k) for k in range(3))
    self.assertEqual(tags, expected)

  @test_util.run_deprecated_v1
  def testHistogramSummary(self):
    with self.cached_session() as s:
      i = array_ops.ones((5, 4, 4, 3))
      with ops.name_scope('outer'):
        summ_op = summary_lib.histogram('inner', i)
      summary_str = s.run(summ_op)
    values = self._parse_summary(summary_str).value
    self.assertLen(values, 1)
    self.assertEqual(values[0].tag, 'outer/inner')

  @test_util.run_deprecated_v1
  def testHistogramSummaryWithFamily(self):
    with self.cached_session() as s:
      i = array_ops.ones((5, 4, 4, 3))
      with ops.name_scope('outer'):
        summ_op = summary_lib.histogram('inner', i, family='family')
        self.assertEqual(summ_op.op.name, 'outer/family/inner')
      summary_str = s.run(summ_op)
    values = self._parse_summary(summary_str).value
    self.assertLen(values, 1)
    self.assertEqual(values[0].tag, 'family/outer/family/inner')

  def testHistogramSummaryTypes(self):
    # Only checks that building the op succeeds for each listed dtype.
    for dtype in (dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.int32,
                  dtypes.float32, dtypes.float64):
      const = constant_op.constant(10, dtype=dtype)
      summary_lib.histogram('h', const)

  @test_util.run_deprecated_v1
  def testAudioSummary(self):
    with self.cached_session() as s:
      i = array_ops.ones((5, 3, 4))
      with ops.name_scope('outer'):
        aud = summary_lib.audio('inner', i, 0.2, max_outputs=3)
      summary_str = s.run(aud)
    values = self._parse_summary(summary_str).value
    self.assertLen(values, 3)
    tags = sorted(v.tag for v in values)
    expected = sorted('outer/inner/audio/{}'.format(k) for k in range(3))
    self.assertEqual(tags, expected)

  @test_util.run_deprecated_v1
  def testAudioSummaryWithFamily(self):
    with self.cached_session() as s:
      i = array_ops.ones((5, 3, 4))
      with ops.name_scope('outer'):
        aud = summary_lib.audio('inner', i, 0.2, max_outputs=3, family='family')
        self.assertEqual(aud.op.name, 'outer/family/inner')
      summary_str = s.run(aud)
    values = self._parse_summary(summary_str).value
    self.assertLen(values, 3)
    tags = sorted(v.tag for v in values)
    expected = sorted(
        'family/outer/family/inner/audio/{}'.format(k) for k in range(3))
    self.assertEqual(tags, expected)

  def testAudioSummaryWithInvalidSampleRate(self):
    # A two-element sample rate is expected to be rejected. Build it outside
    # assertRaises so only the summary op itself can satisfy the check.
    invalid_sample_rate = [22000.0, 22000.0]
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(summary_lib.audio('', [[1.0]], invalid_sample_rate))

  @test_util.run_deprecated_v1
  def testTextSummary(self):
    with self.cached_session():
      # An integer tensor is rejected when the text op is built. The constant
      # is created outside assertRaises so only `text` can raise here.
      num = constant_op.constant(1)
      with self.assertRaises(ValueError):
        summary_lib.text('foo', num)

      # The API accepts vectors.
      arr = constant_op.constant(['one', 'two', 'three'])
      summ = summary_lib.text('foo', arr)
      self.assertEqual(summ.op.type, 'TensorSummaryV2')

      # The API accepts scalars.
      summ = summary_lib.text('foo', constant_op.constant('one'))
      self.assertEqual(summ.op.type, 'TensorSummaryV2')

  @test_util.run_deprecated_v1
  def testSummaryNameConversion(self):
    c = constant_op.constant(3)
    # Spaces and other characters not allowed in op names become '_'.
    s = summary_lib.scalar('name with spaces', c)
    self.assertEqual(s.op.name, 'name_with_spaces')

    s2 = summary_lib.scalar('name with many $#illegal^: characters!', c)
    self.assertEqual(s2.op.name, 'name_with_many___illegal___characters_')

    # A leading slash is dropped from the op name.
    s3 = summary_lib.scalar('/name/with/leading/slash', c)
    self.assertEqual(s3.op.name, 'name/with/leading/slash')

  @test_util.run_deprecated_v1
  def testSummaryWithFamilyMetaGraphExport(self):
    with ops.name_scope('outer'):
      i = constant_op.constant(11)
      summ = summary_lib.scalar('inner', i)
      self.assertEqual(summ.op.name, 'outer/inner')
      summ_f = summary_lib.scalar('inner', i, family='family')
      self.assertEqual(summ_f.op.name, 'outer/family/inner')

    metagraph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='outer')

    with ops.Graph().as_default() as g:
      meta_graph.import_scoped_meta_graph(metagraph_def, graph=g,
                                          import_scope='new_outer')
      # The summaries should exist, but with outer scope renamed.
      new_summ = g.get_tensor_by_name('new_outer/inner:0')
      new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')

      # However, the tags are unaffected.
      with self.cached_session() as s:
        new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
        new_summ_values = self._parse_summary(new_summ_str).value
        self.assertLen(new_summ_values, 1)
        self.assertEqual('outer/inner', new_summ_values[0].tag)
        new_summ_f_values = self._parse_summary(new_summ_f_str).value
        self.assertLen(new_summ_f_values, 1)
        self.assertEqual('family/outer/family/inner',
                         new_summ_f_values[0].tag)
248
249
if __name__ == '__main__':
  # When executed as a script, hand off to the TensorFlow test entry point.
  test.main()
252