1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Functions to read and write images in PNG format.
17
18 #include <string.h>
19 #include <sys/types.h>
20 #include <zlib.h>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise
25 // provokes a compile error. We instead let png.h include what is needed.
26
27 #include "absl/base/casts.h"
28 #include "tensorflow/core/lib/png/png_io.h"
29 #include "tensorflow/core/platform/byte_order.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/png.h"
32
33 namespace tensorflow {
34 namespace png {
35
36 ////////////////////////////////////////////////////////////////////////////////
37 // Encode an 8- or 16-bit rgb/grayscale image to PNG string
38 ////////////////////////////////////////////////////////////////////////////////
39
40 namespace {
41
42 #define PTR_INC(type, ptr, del) \
43 (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
44 #define CPTR_INC(type, ptr, del) \
45 (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \
46 (del)))
47
48 // Convert from 8 bit components to 16. This works in-place.
Convert8to16(const uint8 * p8,int num_comps,int p8_row_bytes,int width,int height_in,uint16 * p16,int p16_row_bytes)49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
50 int width, int height_in, uint16* p16,
51 int p16_row_bytes) {
52 // Force height*row_bytes computations to use 64 bits. Height*width is
53 // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is
54 // height*width*channels*(8bit?1:2) which is therefore only constrained to <
55 // 33 bits.
56 int64 height = static_cast<int64>(height_in);
57
58 // Adjust pointers to copy backwards
59 width *= num_comps;
60 CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8));
61 PTR_INC(uint16, p16,
62 (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16));
63 int bump8 = width * sizeof(*p8) - p8_row_bytes;
64 int bump16 = width * sizeof(*p16) - p16_row_bytes;
65 for (; height-- != 0;
66 CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
67 for (int w = width; w-- != 0; --p8, --p16) {
68 uint32 pix = *p8;
69 pix |= pix << 8;
70 *p16 = static_cast<uint16>(pix);
71 }
72 }
73 }
74
75 #undef PTR_INC
76 #undef CPTR_INC
77
ErrorHandler(png_structp png_ptr,png_const_charp msg)78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
79 DecodeContext* const ctx =
80 absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
81 ctx->error_condition = true;
82 // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
83 VLOG(1) << "PNG error: " << msg;
84 longjmp(png_jmpbuf(png_ptr), 1);
85 }
86
WarningHandler(png_structp png_ptr,png_const_charp msg)87 void WarningHandler(png_structp png_ptr, png_const_charp msg) {
88 LOG(WARNING) << "PNG warning: " << msg;
89 }
90
StringReader(png_structp png_ptr,png_bytep data,png_size_t length)91 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
92 DecodeContext* const ctx =
93 absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
94 if (static_cast<png_size_t>(ctx->data_left) < length) {
95 // Don't zero out the data buffer as it has been lazily allocated (copy on
96 // write) and zeroing it out here can produce an OOM. Since the buffer is
97 // only used for reading data from the image, this doesn't result in any
98 // data leak, so it is safe to just leave the buffer be as it is and just
99 // exit with error.
100 png_error(png_ptr, "More bytes requested to read than available");
101 } else {
102 memcpy(data, ctx->data, length);
103 ctx->data += length;
104 ctx->data_left -= length;
105 }
106 }
107
108 template <typename T>
StringWriter(png_structp png_ptr,png_bytep data,png_size_t length)109 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
110 T* const s = absl::bit_cast<T*>(png_get_io_ptr(png_ptr));
111 s->append(absl::bit_cast<const char*>(data), length);
112 }
113
StringWriterFlush(png_structp png_ptr)114 void StringWriterFlush(png_structp png_ptr) {}
115
check_metadata_string(const string & s)116 char* check_metadata_string(const string& s) {
117 const char* const c_string = s.c_str();
118 const size_t length = s.size();
119 if (strlen(c_string) != length) {
120 LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
121 }
122 return const_cast<char*>(c_string);
123 }
124
125 } // namespace
126
127 // We move CommonInitDecode() and CommonFinishDecode()
128 // out of the CommonDecode() template to save code space.
CommonFreeDecode(DecodeContext * context)129 void CommonFreeDecode(DecodeContext* context) {
130 if (context->png_ptr) {
131 png_destroy_read_struct(&context->png_ptr,
132 context->info_ptr ? &context->info_ptr : nullptr,
133 nullptr);
134 context->png_ptr = nullptr;
135 context->info_ptr = nullptr;
136 }
137 }
138
DecodeHeader(StringPiece png_string,int * width,int * height,int * components,int * channel_bit_depth,std::vector<std::pair<string,string>> * metadata)139 bool DecodeHeader(StringPiece png_string, int* width, int* height,
140 int* components, int* channel_bit_depth,
141 std::vector<std::pair<string, string> >* metadata) {
142 DecodeContext context;
143 // Ask for 16 bits even if there may be fewer. This assures that sniffing
144 // the metadata will succeed in all cases.
145 //
146 // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
147 // metadata with setting up the data conversions. These should be separated.
148 constexpr int kDesiredNumChannels = 1;
149 constexpr int kDesiredChannelBits = 16;
150 if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
151 &context)) {
152 return false;
153 }
154 CHECK_NOTNULL(width);
155 *width = static_cast<int>(context.width);
156 CHECK_NOTNULL(height);
157 *height = static_cast<int>(context.height);
158 if (components != nullptr) {
159 switch (context.color_type) {
160 case PNG_COLOR_TYPE_PALETTE:
161 *components =
162 (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS))
163 ? 4
164 : 3;
165 break;
166 case PNG_COLOR_TYPE_GRAY:
167 *components = 1;
168 break;
169 case PNG_COLOR_TYPE_GRAY_ALPHA:
170 *components = 2;
171 break;
172 case PNG_COLOR_TYPE_RGB:
173 *components = 3;
174 break;
175 case PNG_COLOR_TYPE_RGB_ALPHA:
176 *components = 4;
177 break;
178 default:
179 *components = 0;
180 break;
181 }
182 }
183 if (channel_bit_depth != nullptr) {
184 *channel_bit_depth = context.bit_depth;
185 }
186 if (metadata != nullptr) {
187 metadata->clear();
188 png_textp text_ptr = nullptr;
189 int num_text = 0;
190 png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text);
191 for (int i = 0; i < num_text; i++) {
192 const png_text& text = text_ptr[i];
193 metadata->push_back(std::make_pair(text.key, text.text));
194 }
195 }
196 CommonFreeDecode(&context);
197 return true;
198 }
199
CommonInitDecode(StringPiece png_string,int desired_channels,int desired_channel_bits,DecodeContext * context)200 bool CommonInitDecode(StringPiece png_string, int desired_channels,
201 int desired_channel_bits, DecodeContext* context) {
202 CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
203 << "desired_channel_bits = " << desired_channel_bits;
204 CHECK(0 <= desired_channels && desired_channels <= 4)
205 << "desired_channels = " << desired_channels;
206 context->error_condition = false;
207 context->channels = desired_channels;
208 context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
209 ErrorHandler, WarningHandler);
210 if (!context->png_ptr) {
211 VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
212 return false;
213 }
214 if (setjmp(png_jmpbuf(context->png_ptr))) {
215 VLOG(1) << ": DecodePNG error trapped.";
216 CommonFreeDecode(context);
217 return false;
218 }
219 context->info_ptr = png_create_info_struct(context->png_ptr);
220 if (!context->info_ptr || context->error_condition) {
221 VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
222 CommonFreeDecode(context);
223 return false;
224 }
225 context->data = absl::bit_cast<const uint8*>(png_string.data());
226 context->data_left = png_string.size();
227 png_set_read_fn(context->png_ptr, context, StringReader);
228 png_read_info(context->png_ptr, context->info_ptr);
229 png_get_IHDR(context->png_ptr, context->info_ptr, &context->width,
230 &context->height, &context->bit_depth, &context->color_type,
231 nullptr, nullptr, nullptr);
232 if (context->error_condition) {
233 VLOG(1) << ": DecodePNG <- error during header parsing.";
234 CommonFreeDecode(context);
235 return false;
236 }
237 if (context->width <= 0 || context->height <= 0) {
238 VLOG(1) << ": DecodePNG <- invalid dimensions";
239 CommonFreeDecode(context);
240 return false;
241 }
242 const bool has_tRNS =
243 (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0;
244 if (context->channels == 0) { // Autodetect number of channels
245 if (context->color_type == PNG_COLOR_TYPE_PALETTE) {
246 if (has_tRNS) {
247 context->channels = 4; // RGB + A(tRNS)
248 } else {
249 context->channels = 3; // RGB
250 }
251 } else {
252 context->channels = png_get_channels(context->png_ptr, context->info_ptr);
253 }
254 }
255 const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
256 if ((context->channels & 1) == 0) { // We desire alpha
257 if (has_alpha) { // There is alpha
258 } else if (has_tRNS) {
259 png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha
260 } else {
261 png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1,
262 PNG_FILLER_AFTER);
263 }
264 } else { // We don't want alpha
265 if (has_alpha || has_tRNS) { // There is alpha
266 png_set_strip_alpha(context->png_ptr); // Strip alpha
267 }
268 }
269
270 // If we only want 8 bits, but are given 16, strip off the LS 8 bits
271 if (context->bit_depth > 8 && desired_channel_bits <= 8)
272 png_set_strip_16(context->png_ptr);
273
274 context->need_to_synthesize_16 =
275 (context->bit_depth <= 8 && desired_channel_bits == 16);
276
277 png_set_packing(context->png_ptr);
278 context->num_passes = png_set_interlace_handling(context->png_ptr);
279
280 if (desired_channel_bits > 8 && port::kLittleEndian) {
281 png_set_swap(context->png_ptr);
282 }
283
284 // convert palette to rgb(a) if needs be.
285 if (context->color_type == PNG_COLOR_TYPE_PALETTE)
286 png_set_palette_to_rgb(context->png_ptr);
287
288 // handle grayscale case for source or destination
289 const bool want_gray = (context->channels < 3);
290 const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
291 if (is_gray) { // upconvert gray to 8-bit if needed.
292 if (context->bit_depth < 8) {
293 png_set_expand_gray_1_2_4_to_8(context->png_ptr);
294 }
295 }
296 if (want_gray) { // output is grayscale
297 if (!is_gray)
298 png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG
299 } else { // output is rgb(a)
300 if (is_gray)
301 png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion
302 }
303
304 // Must come last to incorporate all requested transformations.
305 png_read_update_info(context->png_ptr, context->info_ptr);
306 return true;
307 }
308
CommonFinishDecode(png_bytep data,int row_bytes,DecodeContext * context)309 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
310 CHECK_NOTNULL(data);
311
312 // we need to re-set the jump point so that we trap the errors
313 // within *this* function (and not CommonInitDecode())
314 if (setjmp(png_jmpbuf(context->png_ptr))) {
315 VLOG(1) << ": DecodePNG error trapped.";
316 CommonFreeDecode(context);
317 return false;
318 }
319 // png_read_row() takes care of offsetting the pointer based on interlacing
320 for (int p = 0; p < context->num_passes; ++p) {
321 png_bytep row = data;
322 for (int h = context->height; h-- != 0; row += row_bytes) {
323 png_read_row(context->png_ptr, row, nullptr);
324 }
325 }
326
327 // Marks iDAT as valid.
328 png_set_rows(context->png_ptr, context->info_ptr,
329 png_get_rows(context->png_ptr, context->info_ptr));
330 png_read_end(context->png_ptr, context->info_ptr);
331
332 // Clean up.
333 const bool ok = !context->error_condition;
334 CommonFreeDecode(context);
335
336 // Synthesize 16 bits from 8 if requested.
337 if (context->need_to_synthesize_16)
338 Convert8to16(absl::bit_cast<uint8*>(data), context->channels, row_bytes,
339 context->width, context->height, absl::bit_cast<uint16*>(data),
340 row_bytes);
341 return ok;
342 }
343
344 template <typename T>
WriteImageToBuffer(const void * image,int width,int height,int row_bytes,int num_channels,int channel_bits,int compression,T * png_string,const std::vector<std::pair<string,string>> * metadata)345 bool WriteImageToBuffer(
346 const void* image, int width, int height, int row_bytes, int num_channels,
347 int channel_bits, int compression, T* png_string,
348 const std::vector<std::pair<string, string> >* metadata) {
349 CHECK_NOTNULL(image);
350 CHECK_NOTNULL(png_string);
351 // Although this case is checked inside png.cc and issues an error message,
352 // that error causes memory corruption.
353 if (width == 0 || height == 0) return false;
354
355 png_string->resize(0);
356 png_infop info_ptr = nullptr;
357 png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr,
358 ErrorHandler, WarningHandler);
359 if (png_ptr == nullptr) return false;
360 if (setjmp(png_jmpbuf(png_ptr))) {
361 png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr);
362 return false;
363 }
364 info_ptr = png_create_info_struct(png_ptr);
365 if (info_ptr == nullptr) {
366 png_destroy_write_struct(&png_ptr, nullptr);
367 return false;
368 }
369
370 int color_type = -1;
371 switch (num_channels) {
372 case 1:
373 color_type = PNG_COLOR_TYPE_GRAY;
374 break;
375 case 2:
376 color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
377 break;
378 case 3:
379 color_type = PNG_COLOR_TYPE_RGB;
380 break;
381 case 4:
382 color_type = PNG_COLOR_TYPE_RGB_ALPHA;
383 break;
384 default:
385 png_destroy_write_struct(&png_ptr, &info_ptr);
386 return false;
387 }
388
389 png_set_write_fn(png_ptr, png_string, StringWriter<T>, StringWriterFlush);
390 if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
391 png_set_compression_level(png_ptr, compression);
392 png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
393 // There used to be a call to png_set_filter here turning off filtering
394 // entirely, but it produced pessimal compression ratios. I'm not sure
395 // why it was there.
396 png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
397 PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
398 PNG_FILTER_TYPE_DEFAULT);
399 // If we have metadata write to it.
400 if (metadata && !metadata->empty()) {
401 std::vector<png_text> text;
402 for (const auto& pair : *metadata) {
403 png_text txt;
404 txt.compression = PNG_TEXT_COMPRESSION_NONE;
405 txt.key = check_metadata_string(pair.first);
406 txt.text = check_metadata_string(pair.second);
407 text.push_back(txt);
408 }
409 png_set_text(png_ptr, info_ptr, &text[0], text.size());
410 }
411
412 png_write_info(png_ptr, info_ptr);
413 if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr);
414
415 png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
416 for (; height--; row += row_bytes) png_write_row(png_ptr, row);
417 png_write_end(png_ptr, nullptr);
418
419 png_destroy_write_struct(&png_ptr, &info_ptr);
420 return true;
421 }
422
423 template bool WriteImageToBuffer<string>(
424 const void* image, int width, int height, int row_bytes, int num_channels,
425 int channel_bits, int compression, string* png_string,
426 const std::vector<std::pair<string, string> >* metadata);
427 template bool WriteImageToBuffer<tstring>(
428 const void* image, int width, int height, int row_bytes, int num_channels,
429 int channel_bits, int compression, tstring* png_string,
430 const std::vector<std::pair<string, string> >* metadata);
431
432 } // namespace png
433 } // namespace tensorflow
434