# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the actual serialized proto output of the V1 tf.summary ops.

The tensor, audio, and image ops have dedicated tests in adjacent files. The
overall tf.summary API surface also has its own tests in summary_test.py that
check calling the API methods but not the exact serialized proto output.
"""

from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.platform import test
from tensorflow.python.summary import summary


class SummaryV1OpsTest(test.TestCase):
  """Checks the serialized `Summary` protos produced by V1 summary ops."""

  # Expected output of a scalar summary over tags ["c1", "c2"] applied to the
  # constant [10.0, 20.0]: one `Value` per tag, in tag order. Shared by the
  # named and default-named scalar tests so the expectation lives in one place.
  _C1_C2_SCALARS = """
      value { tag: "c1" simple_value: 10.0 }
      value { tag: "c2" simple_value: 20.0 }
  """

  def _AsSummary(self, s):
    """Parses serialized summary bytes `s` into a `summary_pb2.Summary`."""
    summ = summary_pb2.Summary()
    summ.ParseFromString(s)
    return summ

  def testScalarSummary(self):
    """An explicitly named scalar summary emits one simple_value per tag."""
    with self.cached_session():
      const = constant_op.constant([10.0, 20.0])
      summ = logging_ops.scalar_summary(["c1", "c2"], const, name="mysumm")
      value = self.evaluate(summ)
    # The op yields a single serialized proto, i.e. a rank-0 tensor.
    self.assertEqual([], summ.get_shape())
    self.assertProtoEquals(self._C1_C2_SCALARS, self._AsSummary(value))

  def testScalarSummaryDefaultName(self):
    """Same as testScalarSummary, but without passing an op `name`."""
    with self.cached_session():
      const = constant_op.constant([10.0, 20.0])
      summ = logging_ops.scalar_summary(["c1", "c2"], const)
      value = self.evaluate(summ)
    self.assertEqual([], summ.get_shape())
    self.assertProtoEquals(self._C1_C2_SCALARS, self._AsSummary(value))

  @test_util.run_deprecated_v1
  def testMergeSummary(self):
    """Merging a histogram and a scalar summary keeps both values, in order."""
    with self.cached_session():
      const = constant_op.constant(10.0)
      summ1 = summary.histogram("h", const)
      summ2 = logging_ops.scalar_summary("c", const)
      merge = summary.merge([summ1, summ2])
      value = self.evaluate(merge)
    self.assertEqual([], merge.get_shape())
    # The histogram of the single value 10.0 puts its one sample in the middle
    # bucket; the last bucket_limit is the largest float64 (DBL_MAX).
    self.assertProtoEquals("""
      value {
        tag: "h"
        histo {
          min: 10.0
          max: 10.0
          num: 1.0
          sum: 10.0
          sum_squares: 100.0
          bucket_limit: 9.93809490288
          bucket_limit: 10.9319043932
          bucket_limit: 1.7976931348623157e+308
          bucket: 0.0
          bucket: 1.0
          bucket: 0.0
        }
      }
      value { tag: "c" simple_value: 10.0 }
      """, self._AsSummary(value))

  def testMergeAllSummaries(self):
    """merge_all() gathers summaries by collection key, in creation order.

    Summaries created without `collections` are merged by the no-argument
    call; `summ2` is registered only under "foo_key"; a key with no
    registered summaries yields None.
    """
    with ops.Graph().as_default():
      const = constant_op.constant(10.0)
      summ1 = summary.histogram("h", const)
      summ2 = summary.scalar("o", const, collections=["foo_key"])
      summ3 = summary.scalar("c", const)

      merge = summary.merge_all()
      self.assertEqual("MergeSummary", merge.op.type)
      self.assertLen(merge.op.inputs, 2)
      self.assertEqual(summ1, merge.op.inputs[0])
      self.assertEqual(summ3, merge.op.inputs[1])

      merge = summary.merge_all("foo_key")
      self.assertEqual("MergeSummary", merge.op.type)
      self.assertLen(merge.op.inputs, 1)
      self.assertEqual(summ2, merge.op.inputs[0])

      self.assertIsNone(summary.merge_all("bar_key"))


if __name__ == "__main__":
  test.main()