• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for summary V1 image op."""
16
17import numpy as np
18
19from tensorflow.core.framework import summary_pb2
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import image_ops
24import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
25from tensorflow.python.platform import test
26from tensorflow.python.summary import summary
27
28
class SummaryV1ImageOpTest(test.TestCase):
  """Tests for the V1 `summary.image` op on float32 and uint8 image batches."""

  def _AsSummary(self, s):
    """Parse a serialized `Summary` proto.

    Args:
      s: Serialized bytes of a `summary_pb2.Summary`.

    Returns:
      The parsed `summary_pb2.Summary` message.
    """
    summ = summary_pb2.Summary()
    summ.ParseFromString(s)
    return summ

  def _DecodeFirstImage(self, image_summ):
    """Decode the PNG carried by the first value of `image_summ`.

    Evaluates a `decode_png` op, so in graph mode it runs in the currently
    active test session (each test wraps its work in `self.session(...)`).

    Args:
      image_summ: Parsed `summary_pb2.Summary` with at least one image value.

    Returns:
      The decoded image as a numpy array.
    """
    encoded = image_summ.value[0].image.encoded_image_string
    return self.evaluate(image_ops.decode_png(encoded))

  def _CheckProto(self, image_summ, shape, name="img", max_outputs=3):
    """Verify that the non-image parts of the image_summ proto match shape.

    The encoded image bytes are ignored. Only tags and image metadata are
    compared. The check runs on a copy, so `image_summ` is left untouched.

    Args:
      image_summ: Parsed `summary_pb2.Summary` produced by `summary.image`.
      shape: Shape of the summarized batch, `(batch, height, width, depth)`.
      name: The `name` that was passed to `summary.image`.
      max_outputs: The `max_outputs` that was passed to `summary.image`
        (3 is that function's default).
    """
    summ = summary_pb2.Summary()
    summ.CopyFrom(image_summ)
    for v in summ.value:
      v.image.ClearField("encoded_image_string")

    # At most `max_outputs` images are emitted, and never more than the batch
    # contains.
    num_images = min(shape[0], max_outputs)
    # Tags carry an index suffix only when more than one output is allowed.
    if max_outputs > 1:
      tags = ["%s/image/%d" % (name, i) for i in range(num_images)]
    else:
      tags = ["%s/image" % name] * num_images

    height, width, depth = shape[1:]
    expected = "\n".join("""
        value {
          tag: "%s"
          image { height: %d width: %d colorspace: %d }
        }""" % (tag, height, width, depth) for tag in tags)
    self.assertProtoEquals(expected, summ)

  @test_util.run_deprecated_v1
  def testImageSummary(self):
    """Float images are rescaled to uint8, and a NaN pixel becomes bad_color."""
    for depth in (1, 3, 4):
      for positive in (False, True):
        with self.session(graph=ops.Graph()):
          shape = (4, 5, 7) + (depth,)
          # Expected rendering of the NaN pixel, truncated to `depth` channels.
          bad_color = [255, 0, 0, 255][:depth]
          # Build a mostly random image with one nan
          const = np.random.randn(*shape).astype(np.float32)
          const[0, 1, 2] = 0  # Make the nan entry not the max
          # Expected per-image rescaling: when every value is nonnegative, the
          # max maps to 255. Otherwise the max |value| maps to 127 around an
          # offset of 128.
          if positive:
            const = 1 + np.maximum(const, 0)
            scale = 255 / const.reshape(shape[0], -1).max(axis=1)
            offset = 0
          else:
            scale = 127 / np.abs(const.reshape(shape[0], -1)).max(axis=1)
            offset = 128
          # Computed before the NaN is inserted so it holds finite values.
          adjusted = np.floor(scale[:, None, None, None] * const + offset)
          const[0, 1, 2, depth // 2] = np.nan

          # Summarize
          summ = summary.image("img", const)
          value = self.evaluate(summ)
          self.assertEqual([], summ.get_shape())
          image_summ = self._AsSummary(value)

          # Decode the first image and check consistency
          image = self._DecodeFirstImage(image_summ)
          self.assertAllEqual(image[1, 2], bad_color)
          # Patch the NaN pixel with its expected finite value so the rest of
          # the image can be compared in one shot.
          image[1, 2] = adjusted[0, 1, 2]
          self.assertAllClose(image, adjusted[0], rtol=2e-5, atol=2e-5)

          # Check the rest of the proto
          self._CheckProto(image_summ, shape)

  @test_util.run_deprecated_v1
  def testImageSummaryUint8(self):
    """uint8 images are encoded losslessly, without rescaling."""
    np.random.seed(7)
    for depth in (1, 3, 4):
      with self.session(graph=ops.Graph()):
        shape = (4, 5, 7) + (depth,)

        # Build a random uint8 image
        images = np.random.randint(256, size=shape).astype(np.uint8)
        tf_images = ops.convert_to_tensor(images)
        self.assertEqual(tf_images.dtype, dtypes.uint8)

        # Summarize
        summ = summary.image("img", tf_images)
        value = self.evaluate(summ)
        self.assertEqual([], summ.get_shape())
        image_summ = self._AsSummary(value)

        # Decode the first image and check consistency.
        # Since we're uint8, everything should be exact.
        image = self._DecodeFirstImage(image_summ)
        self.assertAllEqual(image, images[0])

        # Check the rest of the proto
        self._CheckProto(image_summ, shape)
111
def _run_tests():
  """Script entry point: delegate to the TensorFlow `test.main()` runner."""
  test.main()


if __name__ == "__main__":
  _run_tests()
114