# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the API surface of the V1 tf.summary ops.

These tests don't check the actual serialized proto summary value for the
more complex summaries (e.g. audio, image). Those tests live separately in
tensorflow/python/kernel_tests/summary_v1_*.py.
"""


from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib


class SummaryTest(test.TestCase):
  """Checks op names and serialized tags produced by the V1 summary API."""

  @test_util.run_deprecated_v1
  def testScalarSummary(self):
    """A scalar summary's tag is prefixed by the enclosing name scope."""
    with self.cached_session() as s:
      i = constant_op.constant(3)
      with ops.name_scope('outer'):
        im = summary_lib.scalar('inner', i)
      summary_str = s.run(im)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    values = summary.value
    self.assertEqual(len(values), 1)
    self.assertEqual(values[0].tag, 'outer/inner')
    self.assertEqual(values[0].simple_value, 3.0)

  @test_util.run_deprecated_v1
  def testScalarSummaryWithFamily(self):
    """`family` nests the op name and prefixes the tag with the family.

    A second summary with the same name inside the same scope gets a
    uniquified (`_1`) op name and tag.
    """
    with self.cached_session() as s:
      i = constant_op.constant(7)
      with ops.name_scope('outer'):
        im1 = summary_lib.scalar('inner', i, family='family')
        self.assertEqual(im1.op.name, 'outer/family/inner')
        im2 = summary_lib.scalar('inner', i, family='family')
        self.assertEqual(im2.op.name, 'outer/family/inner_1')
      sm1, sm2 = s.run([im1, im2])
    summary = summary_pb2.Summary()

    summary.ParseFromString(sm1)
    values = summary.value
    self.assertEqual(len(values), 1)
    self.assertEqual(values[0].tag, 'family/outer/family/inner')
    self.assertEqual(values[0].simple_value, 7.0)

    # ParseFromString replaces the message contents, so the same proto
    # object can be reused for the second summary.
    summary.ParseFromString(sm2)
    values = summary.value
    self.assertEqual(len(values), 1)
    self.assertEqual(values[0].tag, 'family/outer/family/inner_1')
    self.assertEqual(values[0].simple_value, 7.0)

  @test_util.run_deprecated_v1
  def testSummarizingVariable(self):
    """An initialized variable can be summarized as a scalar."""
    with self.cached_session() as s:
      c = constant_op.constant(42.0)
      v = variables.Variable(c)
      ss = summary_lib.scalar('summary', v)
      init = variables.global_variables_initializer()
      s.run(init)
      summ_str = s.run(ss)
      summary = summary_pb2.Summary()
      summary.ParseFromString(summ_str)
      self.assertEqual(len(summary.value), 1)
      value = summary.value[0]
      self.assertEqual(value.tag, 'summary')
      self.assertEqual(value.simple_value, 42.0)

  @test_util.run_deprecated_v1
  def testImageSummary(self):
    """`max_outputs` caps the image count; tags end in `/image/<index>`."""
    with self.cached_session() as s:
      i = array_ops.ones((5, 4, 4, 3))
      with ops.name_scope('outer'):
        im = summary_lib.image('inner', i, max_outputs=3)
      summary_str = s.run(im)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    values = summary.value
    self.assertEqual(len(values), 3)
    tags = sorted(v.tag for v in values)
    # Use a distinct loop name so it does not shadow the input tensor `i`.
    expected = sorted('outer/inner/image/{}'.format(idx) for idx in range(3))
    self.assertEqual(tags, expected)

  @test_util.run_deprecated_v1
  def testImageSummaryWithFamily(self):
    """Image summary tags carry the `family` prefix and nested op name."""
    with self.cached_session() as s:
      i = array_ops.ones((5, 2, 3, 1))
      with ops.name_scope('outer'):
        im = summary_lib.image('inner', i, max_outputs=3, family='family')
        self.assertEqual(im.op.name, 'outer/family/inner')
      summary_str = s.run(im)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    values = summary.value
    self.assertEqual(len(values), 3)
    tags = sorted(v.tag for v in values)
    expected = sorted(
        'family/outer/family/inner/image/{}'.format(idx) for idx in range(3))
    self.assertEqual(tags, expected)

  @test_util.run_deprecated_v1
  def testHistogramSummary(self):
    """A histogram summary produces one value tagged with the scoped name."""
    with self.cached_session() as s:
      i = array_ops.ones((5, 4, 4, 3))
      with ops.name_scope('outer'):
        summ_op = summary_lib.histogram('inner', i)
      summary_str = s.run(summ_op)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    self.assertEqual(len(summary.value), 1)
    self.assertEqual(summary.value[0].tag, 'outer/inner')

  @test_util.run_deprecated_v1
  def testHistogramSummaryWithFamily(self):
    """Histogram summary tags carry the `family` prefix and nested op name."""
    with self.cached_session() as s:
      i = array_ops.ones((5, 4, 4, 3))
      with ops.name_scope('outer'):
        summ_op = summary_lib.histogram('inner', i, family='family')
        self.assertEqual(summ_op.op.name, 'outer/family/inner')
      summary_str = s.run(summ_op)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    self.assertEqual(len(summary.value), 1)
    self.assertEqual(summary.value[0].tag, 'family/outer/family/inner')

  def testHistogramSummaryTypes(self):
    """Building a histogram summary succeeds for each listed dtype."""
    for dtype in (dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.int32,
                  dtypes.float32, dtypes.float64):
      const = constant_op.constant(10, dtype=dtype)
      summary_lib.histogram('h', const)

  @test_util.run_deprecated_v1
  def testAudioSummary(self):
    """`max_outputs` caps the clip count; tags end in `/audio/<index>`."""
    with self.cached_session() as s:
      i = array_ops.ones((5, 3, 4))
      with ops.name_scope('outer'):
        aud = summary_lib.audio('inner', i, 0.2, max_outputs=3)
      summary_str = s.run(aud)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    values = summary.value
    self.assertEqual(len(values), 3)
    tags = sorted(v.tag for v in values)
    # Use a distinct loop name so it does not shadow the input tensor `i`.
    expected = sorted('outer/inner/audio/{}'.format(idx) for idx in range(3))
    self.assertEqual(tags, expected)

  @test_util.run_deprecated_v1
  def testAudioSummaryWithFamily(self):
    """Audio summary tags carry the `family` prefix and nested op name."""
    with self.cached_session() as s:
      i = array_ops.ones((5, 3, 4))
      with ops.name_scope('outer'):
        aud = summary_lib.audio('inner', i, 0.2, max_outputs=3, family='family')
        self.assertEqual(aud.op.name, 'outer/family/inner')
      summary_str = s.run(aud)
    summary = summary_pb2.Summary()
    summary.ParseFromString(summary_str)
    values = summary.value
    self.assertEqual(len(values), 3)
    tags = sorted(v.tag for v in values)
    expected = sorted(
        'family/outer/family/inner/audio/{}'.format(idx) for idx in range(3))
    self.assertEqual(tags, expected)

  def testAudioSummaryWithInvalidSampleRate(self):
    """A two-element (non-scalar) sample rate raises InvalidArgumentError."""
    invalid_sample_rate = [22000.0, 22000.0]
    # Only the op under test belongs inside the assertRaises context.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(summary_lib.audio('', [[1.0]], invalid_sample_rate))

  @test_util.run_deprecated_v1
  def testTextSummary(self):
    """`text` raises ValueError for an int tensor; accepts string tensors."""
    with self.cached_session():
      num = array_ops.constant(1)
      # Only the call under test belongs inside the assertRaises context.
      with self.assertRaises(ValueError):
        summary_lib.text('foo', num)

      # The API accepts vectors.
      arr = array_ops.constant(['one', 'two', 'three'])
      summ = summary_lib.text('foo', arr)
      self.assertEqual(summ.op.type, 'TensorSummaryV2')

      # The API accepts scalars.
      summ = summary_lib.text('foo', array_ops.constant('one'))
      self.assertEqual(summ.op.type, 'TensorSummaryV2')

  @test_util.run_deprecated_v1
  def testSummaryNameConversion(self):
    """Illegal name characters become `_`; a leading slash is dropped."""
    c = constant_op.constant(3)
    s = summary_lib.scalar('name with spaces', c)
    self.assertEqual(s.op.name, 'name_with_spaces')

    s2 = summary_lib.scalar('name with many $#illegal^: characters!', c)
    self.assertEqual(s2.op.name, 'name_with_many___illegal___characters_')

    s3 = summary_lib.scalar('/name/with/leading/slash', c)
    self.assertEqual(s3.op.name, 'name/with/leading/slash')

  @test_util.run_deprecated_v1
  def testSummaryWithFamilyMetaGraphExport(self):
    """Re-importing a scoped meta graph renames ops but keeps summary tags."""
    with ops.name_scope('outer'):
      i = constant_op.constant(11)
      summ = summary_lib.scalar('inner', i)
      self.assertEqual(summ.op.name, 'outer/inner')
      summ_f = summary_lib.scalar('inner', i, family='family')
      self.assertEqual(summ_f.op.name, 'outer/family/inner')

    metagraph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='outer')

    with ops.Graph().as_default() as g:
      meta_graph.import_scoped_meta_graph(metagraph_def, graph=g,
                                          import_scope='new_outer')
      # The summaries should exist, but with outer scope renamed.
      new_summ = g.get_tensor_by_name('new_outer/inner:0')
      new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')

      # However, the tags are unaffected. The session is opened while `g` is
      # the default graph so it can run the re-imported tensors.
      with self.cached_session() as s:
        new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
        new_summ_pb = summary_pb2.Summary()
        new_summ_pb.ParseFromString(new_summ_str)
        self.assertEqual('outer/inner', new_summ_pb.value[0].tag)
        new_summ_f_pb = summary_pb2.Summary()
        new_summ_f_pb.ParseFromString(new_summ_f_str)
        self.assertEqual('family/outer/family/inner',
                         new_summ_f_pb.value[0].tag)


if __name__ == '__main__':
  test.main()