# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for summary V1 image op."""

import numpy as np

from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import image_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.summary import summary


class SummaryV1ImageOpTest(test.TestCase):
  """Checks the serialized output of `summary.image` for float and uint8 input."""

  def _AsSummary(self, s):
    """Parses the serialized bytes `s` into a `summary_pb2.Summary` proto."""
    summ = summary_pb2.Summary()
    summ.ParseFromString(s)
    return summ

  def _DecodeFirstImage(self, image_summ):
    """Decodes the PNG stored in the first value of `image_summ`.

    Must be called while the session/graph used by the test is still the
    default, because the decode op is built and evaluated in that context.

    Args:
      image_summ: A parsed `Summary` proto whose first value holds an image.

    Returns:
      The decoded image as returned by `self.evaluate`.
    """
    encoded = image_summ.value[0].image.encoded_image_string
    return self.evaluate(image_ops.decode_png(encoded))

  def _CheckProto(self, image_summ, shape):
    """Verify that the non-image parts of the image_summ proto match shape.

    Note: this clears `encoded_image_string` on every value of `image_summ`
    in place, so decode any image you still need before calling it.

    Args:
      image_summ: A parsed `Summary` proto produced for the tag "img".
      shape: Tuple `(batch, height, width, depth)` of the summarized images;
        everything after the batch entry is compared against the proto's
        height, width and colorspace fields.
    """
    # Strip the PNG payloads so only the metadata takes part in the comparison.
    for v in image_summ.value:
      v.image.ClearField("encoded_image_string")
    # Only the first 3 images are returned.
    expected = "\n".join("""
        value {
          tag: "img/image/%d"
          image { height: %d width: %d colorspace: %d }
        }""" % ((i,) + shape[1:]) for i in range(3))
    self.assertProtoEquals(expected, image_summ)

  @test_util.run_deprecated_v1
  def testImageSummary(self):
    """Float images are rescaled to [0, 255]; the NaN pixel gets the bad color."""
    # Seed the generator so a failure can be reproduced, matching
    # testImageSummaryUint8.
    np.random.seed(7)
    for depth in (1, 3, 4):
      for positive in (False, True):
        with self.session(graph=ops.Graph()):
          shape = (4, 5, 7) + (depth,)
          # Expected rendering of the pixel that contains a NaN, truncated to
          # the number of channels under test.
          bad_color = [255, 0, 0, 255][:depth]
          # Build a mostly random image with one nan
          const = np.random.randn(*shape).astype(np.float32)
          const[0, 1, 2] = 0  # Make the nan entry not the max
          if positive:
            # Non-negative input: expect scaling so each image's max is 255.
            const = 1 + np.maximum(const, 0)
            scale = 255 / const.reshape(4, -1).max(axis=1)
            offset = 0
          else:
            # Signed input: expect scaling by the per-image max magnitude,
            # centred on 128.
            scale = 127 / np.abs(const.reshape(4, -1)).max(axis=1)
            offset = 128
          # Reference pixel values, computed before the NaN is injected.
          adjusted = np.floor(scale[:, None, None, None] * const + offset)
          const[0, 1, 2, depth // 2] = np.nan

          # Summarize
          summ = summary.image("img", const)
          value = self.evaluate(summ)
          self.assertEqual([], summ.get_shape())
          image_summ = self._AsSummary(value)

          # Decode the first image and check consistency
          image = self._DecodeFirstImage(image_summ)
          self.assertAllEqual(image[1, 2], bad_color)
          # Overwrite the NaN pixel with the reference value so the remaining
          # pixels can be compared in a single call.
          image[1, 2] = adjusted[0, 1, 2]
          self.assertAllClose(image, adjusted[0], rtol=2e-5, atol=2e-5)

          # Check the rest of the proto
          self._CheckProto(image_summ, shape)

  @test_util.run_deprecated_v1
  def testImageSummaryUint8(self):
    """uint8 images round-trip through the summary without any change."""
    np.random.seed(7)
    for depth in (1, 3, 4):
      with self.session(graph=ops.Graph()):
        shape = (4, 5, 7) + (depth,)

        # Build a random uint8 image
        images = np.random.randint(256, size=shape).astype(np.uint8)
        tf_images = ops.convert_to_tensor(images)
        self.assertEqual(tf_images.dtype, dtypes.uint8)

        # Summarize
        summ = summary.image("img", tf_images)
        value = self.evaluate(summ)
        self.assertEqual([], summ.get_shape())
        image_summ = self._AsSummary(value)

        # Decode the first image and check consistency.
        # Since we're uint8, everything should be exact.
        image = self._DecodeFirstImage(image_summ)
        self.assertAllEqual(image, images[0])

        # Check the rest of the proto
        self._CheckProto(image_summ, shape)


if __name__ == "__main__":
  test.main()