• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Functions to read and write images in PNG format.
17 
18 #include <string.h>
19 #include <sys/types.h>
20 #include <zlib.h>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise
25 // provokes a compile error. We instead let png.h include what is needed.
26 
27 #include "absl/base/casts.h"
28 #include "tensorflow/core/lib/png/png_io.h"
29 #include "tensorflow/core/platform/byte_order.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/png.h"
32 
33 namespace tensorflow {
34 namespace png {
35 
36 ////////////////////////////////////////////////////////////////////////////////
37 // Encode an 8- or 16-bit rgb/grayscale image to PNG string
38 ////////////////////////////////////////////////////////////////////////////////
39 
40 namespace {
41 
42 #define PTR_INC(type, ptr, del) \
43   (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
44 #define CPTR_INC(type, ptr, del)                                            \
45   (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \
46                                        (del)))
47 
48 // Convert from 8 bit components to 16. This works in-place.
Convert8to16(const uint8 * p8,int num_comps,int p8_row_bytes,int width,int height_in,uint16 * p16,int p16_row_bytes)49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
50                          int width, int height_in, uint16* p16,
51                          int p16_row_bytes) {
52   // Force height*row_bytes computations to use 64 bits. Height*width is
53   // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is
54   // height*width*channels*(8bit?1:2) which is therefore only constrained to <
55   // 33 bits.
56   int64 height = static_cast<int64>(height_in);
57 
58   // Adjust pointers to copy backwards
59   width *= num_comps;
60   CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8));
61   PTR_INC(uint16, p16,
62           (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16));
63   int bump8 = width * sizeof(*p8) - p8_row_bytes;
64   int bump16 = width * sizeof(*p16) - p16_row_bytes;
65   for (; height-- != 0;
66        CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
67     for (int w = width; w-- != 0; --p8, --p16) {
68       uint32 pix = *p8;
69       pix |= pix << 8;
70       *p16 = static_cast<uint16>(pix);
71     }
72   }
73 }
74 
75 #undef PTR_INC
76 #undef CPTR_INC
77 
ErrorHandler(png_structp png_ptr,png_const_charp msg)78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
79   DecodeContext* const ctx =
80       absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
81   ctx->error_condition = true;
82   // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
83   VLOG(1) << "PNG error: " << msg;
84   longjmp(png_jmpbuf(png_ptr), 1);
85 }
86 
WarningHandler(png_structp png_ptr,png_const_charp msg)87 void WarningHandler(png_structp png_ptr, png_const_charp msg) {
88   LOG(WARNING) << "PNG warning: " << msg;
89 }
90 
StringReader(png_structp png_ptr,png_bytep data,png_size_t length)91 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
92   DecodeContext* const ctx =
93       absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
94   if (static_cast<png_size_t>(ctx->data_left) < length) {
95     // Don't zero out the data buffer as it has been lazily allocated (copy on
96     // write) and zeroing it out here can produce an OOM. Since the buffer is
97     // only used for reading data from the image, this doesn't result in any
98     // data leak, so it is safe to just leave the buffer be as it is and just
99     // exit with error.
100     png_error(png_ptr, "More bytes requested to read than available");
101   } else {
102     memcpy(data, ctx->data, length);
103     ctx->data += length;
104     ctx->data_left -= length;
105   }
106 }
107 
108 template <typename T>
StringWriter(png_structp png_ptr,png_bytep data,png_size_t length)109 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
110   T* const s = absl::bit_cast<T*>(png_get_io_ptr(png_ptr));
111   s->append(absl::bit_cast<const char*>(data), length);
112 }
113 
StringWriterFlush(png_structp png_ptr)114 void StringWriterFlush(png_structp png_ptr) {}
115 
check_metadata_string(const string & s)116 char* check_metadata_string(const string& s) {
117   const char* const c_string = s.c_str();
118   const size_t length = s.size();
119   if (strlen(c_string) != length) {
120     LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
121   }
122   return const_cast<char*>(c_string);
123 }
124 
125 }  // namespace
126 
127 // We move CommonInitDecode() and CommonFinishDecode()
128 // out of the CommonDecode() template to save code space.
CommonFreeDecode(DecodeContext * context)129 void CommonFreeDecode(DecodeContext* context) {
130   if (context->png_ptr) {
131     png_destroy_read_struct(&context->png_ptr,
132                             context->info_ptr ? &context->info_ptr : nullptr,
133                             nullptr);
134     context->png_ptr = nullptr;
135     context->info_ptr = nullptr;
136   }
137 }
138 
DecodeHeader(StringPiece png_string,int * width,int * height,int * components,int * channel_bit_depth,std::vector<std::pair<string,string>> * metadata)139 bool DecodeHeader(StringPiece png_string, int* width, int* height,
140                   int* components, int* channel_bit_depth,
141                   std::vector<std::pair<string, string> >* metadata) {
142   DecodeContext context;
143   // Ask for 16 bits even if there may be fewer.  This assures that sniffing
144   // the metadata will succeed in all cases.
145   //
146   // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
147   // metadata with setting up the data conversions.  These should be separated.
148   constexpr int kDesiredNumChannels = 1;
149   constexpr int kDesiredChannelBits = 16;
150   if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
151                         &context)) {
152     return false;
153   }
154   CHECK_NOTNULL(width);
155   *width = static_cast<int>(context.width);
156   CHECK_NOTNULL(height);
157   *height = static_cast<int>(context.height);
158   if (components != nullptr) {
159     switch (context.color_type) {
160       case PNG_COLOR_TYPE_PALETTE:
161         *components =
162             (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS))
163                 ? 4
164                 : 3;
165         break;
166       case PNG_COLOR_TYPE_GRAY:
167         *components = 1;
168         break;
169       case PNG_COLOR_TYPE_GRAY_ALPHA:
170         *components = 2;
171         break;
172       case PNG_COLOR_TYPE_RGB:
173         *components = 3;
174         break;
175       case PNG_COLOR_TYPE_RGB_ALPHA:
176         *components = 4;
177         break;
178       default:
179         *components = 0;
180         break;
181     }
182   }
183   if (channel_bit_depth != nullptr) {
184     *channel_bit_depth = context.bit_depth;
185   }
186   if (metadata != nullptr) {
187     metadata->clear();
188     png_textp text_ptr = nullptr;
189     int num_text = 0;
190     png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text);
191     for (int i = 0; i < num_text; i++) {
192       const png_text& text = text_ptr[i];
193       metadata->push_back(std::make_pair(text.key, text.text));
194     }
195   }
196   CommonFreeDecode(&context);
197   return true;
198 }
199 
CommonInitDecode(StringPiece png_string,int desired_channels,int desired_channel_bits,DecodeContext * context)200 bool CommonInitDecode(StringPiece png_string, int desired_channels,
201                       int desired_channel_bits, DecodeContext* context) {
202   CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
203       << "desired_channel_bits = " << desired_channel_bits;
204   CHECK(0 <= desired_channels && desired_channels <= 4)
205       << "desired_channels = " << desired_channels;
206   context->error_condition = false;
207   context->channels = desired_channels;
208   context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
209                                             ErrorHandler, WarningHandler);
210   if (!context->png_ptr) {
211     VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
212     return false;
213   }
214   if (setjmp(png_jmpbuf(context->png_ptr))) {
215     VLOG(1) << ": DecodePNG error trapped.";
216     CommonFreeDecode(context);
217     return false;
218   }
219   context->info_ptr = png_create_info_struct(context->png_ptr);
220   if (!context->info_ptr || context->error_condition) {
221     VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
222     CommonFreeDecode(context);
223     return false;
224   }
225   context->data = absl::bit_cast<const uint8*>(png_string.data());
226   context->data_left = png_string.size();
227   png_set_read_fn(context->png_ptr, context, StringReader);
228   png_read_info(context->png_ptr, context->info_ptr);
229   png_get_IHDR(context->png_ptr, context->info_ptr, &context->width,
230                &context->height, &context->bit_depth, &context->color_type,
231                nullptr, nullptr, nullptr);
232   if (context->error_condition) {
233     VLOG(1) << ": DecodePNG <- error during header parsing.";
234     CommonFreeDecode(context);
235     return false;
236   }
237   if (context->width <= 0 || context->height <= 0) {
238     VLOG(1) << ": DecodePNG <- invalid dimensions";
239     CommonFreeDecode(context);
240     return false;
241   }
242   const bool has_tRNS =
243       (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0;
244   if (context->channels == 0) {  // Autodetect number of channels
245     if (context->color_type == PNG_COLOR_TYPE_PALETTE) {
246       if (has_tRNS) {
247         context->channels = 4;  // RGB + A(tRNS)
248       } else {
249         context->channels = 3;  // RGB
250       }
251     } else {
252       context->channels = png_get_channels(context->png_ptr, context->info_ptr);
253     }
254   }
255   const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
256   if ((context->channels & 1) == 0) {  // We desire alpha
257     if (has_alpha) {                   // There is alpha
258     } else if (has_tRNS) {
259       png_set_tRNS_to_alpha(context->png_ptr);  // Convert transparency to alpha
260     } else {
261       png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1,
262                         PNG_FILLER_AFTER);
263     }
264   } else {                                    // We don't want alpha
265     if (has_alpha || has_tRNS) {              // There is alpha
266       png_set_strip_alpha(context->png_ptr);  // Strip alpha
267     }
268   }
269 
270   // If we only want 8 bits, but are given 16, strip off the LS 8 bits
271   if (context->bit_depth > 8 && desired_channel_bits <= 8)
272     png_set_strip_16(context->png_ptr);
273 
274   context->need_to_synthesize_16 =
275       (context->bit_depth <= 8 && desired_channel_bits == 16);
276 
277   png_set_packing(context->png_ptr);
278   context->num_passes = png_set_interlace_handling(context->png_ptr);
279 
280   if (desired_channel_bits > 8 && port::kLittleEndian) {
281     png_set_swap(context->png_ptr);
282   }
283 
284   // convert palette to rgb(a) if needs be.
285   if (context->color_type == PNG_COLOR_TYPE_PALETTE)
286     png_set_palette_to_rgb(context->png_ptr);
287 
288   // handle grayscale case for source or destination
289   const bool want_gray = (context->channels < 3);
290   const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
291   if (is_gray) {  // upconvert gray to 8-bit if needed.
292     if (context->bit_depth < 8) {
293       png_set_expand_gray_1_2_4_to_8(context->png_ptr);
294     }
295   }
296   if (want_gray) {  // output is grayscale
297     if (!is_gray)
298       png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587);  // 601, JPG
299   } else {  // output is rgb(a)
300     if (is_gray)
301       png_set_gray_to_rgb(context->png_ptr);  // Enable gray -> RGB conversion
302   }
303 
304   // Must come last to incorporate all requested transformations.
305   png_read_update_info(context->png_ptr, context->info_ptr);
306   return true;
307 }
308 
CommonFinishDecode(png_bytep data,int row_bytes,DecodeContext * context)309 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
310   CHECK_NOTNULL(data);
311 
312   // we need to re-set the jump point so that we trap the errors
313   // within *this* function (and not CommonInitDecode())
314   if (setjmp(png_jmpbuf(context->png_ptr))) {
315     VLOG(1) << ": DecodePNG error trapped.";
316     CommonFreeDecode(context);
317     return false;
318   }
319   // png_read_row() takes care of offsetting the pointer based on interlacing
320   for (int p = 0; p < context->num_passes; ++p) {
321     png_bytep row = data;
322     for (int h = context->height; h-- != 0; row += row_bytes) {
323       png_read_row(context->png_ptr, row, nullptr);
324     }
325   }
326 
327   // Marks iDAT as valid.
328   png_set_rows(context->png_ptr, context->info_ptr,
329                png_get_rows(context->png_ptr, context->info_ptr));
330   png_read_end(context->png_ptr, context->info_ptr);
331 
332   // Clean up.
333   const bool ok = !context->error_condition;
334   CommonFreeDecode(context);
335 
336   // Synthesize 16 bits from 8 if requested.
337   if (context->need_to_synthesize_16)
338     Convert8to16(absl::bit_cast<uint8*>(data), context->channels, row_bytes,
339                  context->width, context->height, absl::bit_cast<uint16*>(data),
340                  row_bytes);
341   return ok;
342 }
343 
344 template <typename T>
WriteImageToBuffer(const void * image,int width,int height,int row_bytes,int num_channels,int channel_bits,int compression,T * png_string,const std::vector<std::pair<string,string>> * metadata)345 bool WriteImageToBuffer(
346     const void* image, int width, int height, int row_bytes, int num_channels,
347     int channel_bits, int compression, T* png_string,
348     const std::vector<std::pair<string, string> >* metadata) {
349   CHECK_NOTNULL(image);
350   CHECK_NOTNULL(png_string);
351   // Although this case is checked inside png.cc and issues an error message,
352   // that error causes memory corruption.
353   if (width == 0 || height == 0) return false;
354 
355   png_string->resize(0);
356   png_infop info_ptr = nullptr;
357   png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr,
358                                                 ErrorHandler, WarningHandler);
359   if (png_ptr == nullptr) return false;
360   if (setjmp(png_jmpbuf(png_ptr))) {
361     png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr);
362     return false;
363   }
364   info_ptr = png_create_info_struct(png_ptr);
365   if (info_ptr == nullptr) {
366     png_destroy_write_struct(&png_ptr, nullptr);
367     return false;
368   }
369 
370   int color_type = -1;
371   switch (num_channels) {
372     case 1:
373       color_type = PNG_COLOR_TYPE_GRAY;
374       break;
375     case 2:
376       color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
377       break;
378     case 3:
379       color_type = PNG_COLOR_TYPE_RGB;
380       break;
381     case 4:
382       color_type = PNG_COLOR_TYPE_RGB_ALPHA;
383       break;
384     default:
385       png_destroy_write_struct(&png_ptr, &info_ptr);
386       return false;
387   }
388 
389   png_set_write_fn(png_ptr, png_string, StringWriter<T>, StringWriterFlush);
390   if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
391   png_set_compression_level(png_ptr, compression);
392   png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
393   // There used to be a call to png_set_filter here turning off filtering
394   // entirely, but it produced pessimal compression ratios.  I'm not sure
395   // why it was there.
396   png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
397                PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
398                PNG_FILTER_TYPE_DEFAULT);
399   // If we have metadata write to it.
400   if (metadata && !metadata->empty()) {
401     std::vector<png_text> text;
402     for (const auto& pair : *metadata) {
403       png_text txt;
404       txt.compression = PNG_TEXT_COMPRESSION_NONE;
405       txt.key = check_metadata_string(pair.first);
406       txt.text = check_metadata_string(pair.second);
407       text.push_back(txt);
408     }
409     png_set_text(png_ptr, info_ptr, &text[0], text.size());
410   }
411 
412   png_write_info(png_ptr, info_ptr);
413   if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr);
414 
415   png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
416   for (; height--; row += row_bytes) png_write_row(png_ptr, row);
417   png_write_end(png_ptr, nullptr);
418 
419   png_destroy_write_struct(&png_ptr, &info_ptr);
420   return true;
421 }
422 
423 template bool WriteImageToBuffer<string>(
424     const void* image, int width, int height, int row_bytes, int num_channels,
425     int channel_bits, int compression, string* png_string,
426     const std::vector<std::pair<string, string> >* metadata);
427 template bool WriteImageToBuffer<tstring>(
428     const void* image, int width, int height, int row_bytes, int num_channels,
429     int channel_bits, int compression, tstring* png_string,
430     const std::vector<std::pair<string, string> >* metadata);
431 
432 }  // namespace png
433 }  // namespace tensorflow
434