• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Functions to read and write images in PNG format.
17 
18 #include <string.h>
19 #include <sys/types.h>
20 #include <zlib.h>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise
25 // provokes a compile error. We instead let png.h include what is needed.
26 
27 #include "absl/base/casts.h"
28 #include "tensorflow/core/lib/png/png_io.h"
29 #include "tensorflow/core/platform/byte_order.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/png.h"
32 
33 namespace tensorflow {
34 namespace png {
35 
36 ////////////////////////////////////////////////////////////////////////////////
37 // Encode an 8- or 16-bit rgb/grayscale image to PNG string
38 ////////////////////////////////////////////////////////////////////////////////
39 
40 namespace {
41 
42 #define PTR_INC(type, ptr, del) \
43   (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
44 #define CPTR_INC(type, ptr, del)                                            \
45   (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \
46                                        (del)))
47 
48 // Convert from 8 bit components to 16. This works in-place.
Convert8to16(const uint8 * p8,int num_comps,int p8_row_bytes,int width,int height_in,uint16 * p16,int p16_row_bytes)49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
50                          int width, int height_in, uint16* p16,
51                          int p16_row_bytes) {
52   // Force height*row_bytes computations to use 64 bits. Height*width is
53   // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is
54   // height*width*channels*(8bit?1:2) which is therefore only constrained to <
55   // 33 bits.
56   int64 height = static_cast<int64>(height_in);
57 
58   // Adjust pointers to copy backwards
59   width *= num_comps;
60   CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8));
61   PTR_INC(uint16, p16,
62           (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16));
63   int bump8 = width * sizeof(*p8) - p8_row_bytes;
64   int bump16 = width * sizeof(*p16) - p16_row_bytes;
65   for (; height-- != 0;
66        CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
67     for (int w = width; w-- != 0; --p8, --p16) {
68       uint32 pix = *p8;
69       pix |= pix << 8;
70       *p16 = static_cast<uint16>(pix);
71     }
72   }
73 }
74 
75 #undef PTR_INC
76 #undef CPTR_INC
77 
ErrorHandler(png_structp png_ptr,png_const_charp msg)78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
79   DecodeContext* const ctx =
80       absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
81   ctx->error_condition = true;
82   // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
83   VLOG(1) << "PNG error: " << msg;
84   longjmp(png_jmpbuf(png_ptr), 1);
85 }
86 
WarningHandler(png_structp png_ptr,png_const_charp msg)87 void WarningHandler(png_structp png_ptr, png_const_charp msg) {
88   LOG(WARNING) << "PNG warning: " << msg;
89 }
90 
StringReader(png_structp png_ptr,png_bytep data,png_size_t length)91 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
92   DecodeContext* const ctx =
93       absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
94   if (static_cast<png_size_t>(ctx->data_left) < length) {
95     // Don't zero out the data buffer as it has been lazily allocated (copy on
96     // write) and zeroing it out here can produce an OOM. Since the buffer is
97     // only used for reading data from the image, this doesn't result in any
98     // data leak, so it is safe to just leave the buffer be as it is and just
99     // exit with error.
100     png_error(png_ptr, "More bytes requested to read than available");
101   } else {
102     memcpy(data, ctx->data, length);
103     ctx->data += length;
104     ctx->data_left -= length;
105   }
106 }
107 
StringWriter(png_structp png_ptr,png_bytep data,png_size_t length)108 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
109   string* const s = absl::bit_cast<string*>(png_get_io_ptr(png_ptr));
110   s->append(absl::bit_cast<const char*>(data), length);
111 }
112 
StringWriterFlush(png_structp png_ptr)113 void StringWriterFlush(png_structp png_ptr) {}
114 
check_metadata_string(const string & s)115 char* check_metadata_string(const string& s) {
116   const char* const c_string = s.c_str();
117   const size_t length = s.size();
118   if (strlen(c_string) != length) {
119     LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
120   }
121   return const_cast<char*>(c_string);
122 }
123 
124 }  // namespace
125 
126 // We move CommonInitDecode() and CommonFinishDecode()
127 // out of the CommonDecode() template to save code space.
CommonFreeDecode(DecodeContext * context)128 void CommonFreeDecode(DecodeContext* context) {
129   if (context->png_ptr) {
130     png_destroy_read_struct(&context->png_ptr,
131                             context->info_ptr ? &context->info_ptr : nullptr,
132                             nullptr);
133     context->png_ptr = nullptr;
134     context->info_ptr = nullptr;
135   }
136 }
137 
DecodeHeader(StringPiece png_string,int * width,int * height,int * components,int * channel_bit_depth,std::vector<std::pair<string,string>> * metadata)138 bool DecodeHeader(StringPiece png_string, int* width, int* height,
139                   int* components, int* channel_bit_depth,
140                   std::vector<std::pair<string, string> >* metadata) {
141   DecodeContext context;
142   // Ask for 16 bits even if there may be fewer.  This assures that sniffing
143   // the metadata will succeed in all cases.
144   //
145   // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
146   // metadata with setting up the data conversions.  These should be separated.
147   constexpr int kDesiredNumChannels = 1;
148   constexpr int kDesiredChannelBits = 16;
149   if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
150                         &context)) {
151     return false;
152   }
153   CHECK_NOTNULL(width);
154   *width = static_cast<int>(context.width);
155   CHECK_NOTNULL(height);
156   *height = static_cast<int>(context.height);
157   if (components != nullptr) {
158     switch (context.color_type) {
159       case PNG_COLOR_TYPE_PALETTE:
160         *components =
161             (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS))
162                 ? 4
163                 : 3;
164         break;
165       case PNG_COLOR_TYPE_GRAY:
166         *components = 1;
167         break;
168       case PNG_COLOR_TYPE_GRAY_ALPHA:
169         *components = 2;
170         break;
171       case PNG_COLOR_TYPE_RGB:
172         *components = 3;
173         break;
174       case PNG_COLOR_TYPE_RGB_ALPHA:
175         *components = 4;
176         break;
177       default:
178         *components = 0;
179         break;
180     }
181   }
182   if (channel_bit_depth != nullptr) {
183     *channel_bit_depth = context.bit_depth;
184   }
185   if (metadata != nullptr) {
186     metadata->clear();
187     png_textp text_ptr = nullptr;
188     int num_text = 0;
189     png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text);
190     for (int i = 0; i < num_text; i++) {
191       const png_text& text = text_ptr[i];
192       metadata->push_back(std::make_pair(text.key, text.text));
193     }
194   }
195   CommonFreeDecode(&context);
196   return true;
197 }
198 
CommonInitDecode(StringPiece png_string,int desired_channels,int desired_channel_bits,DecodeContext * context)199 bool CommonInitDecode(StringPiece png_string, int desired_channels,
200                       int desired_channel_bits, DecodeContext* context) {
201   CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
202       << "desired_channel_bits = " << desired_channel_bits;
203   CHECK(0 <= desired_channels && desired_channels <= 4)
204       << "desired_channels = " << desired_channels;
205   context->error_condition = false;
206   context->channels = desired_channels;
207   context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
208                                             ErrorHandler, WarningHandler);
209   if (!context->png_ptr) {
210     VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
211     return false;
212   }
213   if (setjmp(png_jmpbuf(context->png_ptr))) {
214     VLOG(1) << ": DecodePNG error trapped.";
215     CommonFreeDecode(context);
216     return false;
217   }
218   context->info_ptr = png_create_info_struct(context->png_ptr);
219   if (!context->info_ptr || context->error_condition) {
220     VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
221     CommonFreeDecode(context);
222     return false;
223   }
224   context->data = absl::bit_cast<const uint8*>(png_string.data());
225   context->data_left = png_string.size();
226   png_set_read_fn(context->png_ptr, context, StringReader);
227   png_read_info(context->png_ptr, context->info_ptr);
228   png_get_IHDR(context->png_ptr, context->info_ptr, &context->width,
229                &context->height, &context->bit_depth, &context->color_type,
230                nullptr, nullptr, nullptr);
231   if (context->error_condition) {
232     VLOG(1) << ": DecodePNG <- error during header parsing.";
233     CommonFreeDecode(context);
234     return false;
235   }
236   if (context->width <= 0 || context->height <= 0) {
237     VLOG(1) << ": DecodePNG <- invalid dimensions";
238     CommonFreeDecode(context);
239     return false;
240   }
241   const bool has_tRNS =
242       (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0;
243   if (context->channels == 0) {  // Autodetect number of channels
244     if (context->color_type == PNG_COLOR_TYPE_PALETTE) {
245       if (has_tRNS) {
246         context->channels = 4;  // RGB + A(tRNS)
247       } else {
248         context->channels = 3;  // RGB
249       }
250     } else {
251       context->channels = png_get_channels(context->png_ptr, context->info_ptr);
252     }
253   }
254   const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
255   if ((context->channels & 1) == 0) {  // We desire alpha
256     if (has_alpha) {                   // There is alpha
257     } else if (has_tRNS) {
258       png_set_tRNS_to_alpha(context->png_ptr);  // Convert transparency to alpha
259     } else {
260       png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1,
261                         PNG_FILLER_AFTER);
262     }
263   } else {                                    // We don't want alpha
264     if (has_alpha || has_tRNS) {              // There is alpha
265       png_set_strip_alpha(context->png_ptr);  // Strip alpha
266     }
267   }
268 
269   // If we only want 8 bits, but are given 16, strip off the LS 8 bits
270   if (context->bit_depth > 8 && desired_channel_bits <= 8)
271     png_set_strip_16(context->png_ptr);
272 
273   context->need_to_synthesize_16 =
274       (context->bit_depth <= 8 && desired_channel_bits == 16);
275 
276   png_set_packing(context->png_ptr);
277   context->num_passes = png_set_interlace_handling(context->png_ptr);
278 
279   if (desired_channel_bits > 8 && port::kLittleEndian) {
280     png_set_swap(context->png_ptr);
281   }
282 
283   // convert palette to rgb(a) if needs be.
284   if (context->color_type == PNG_COLOR_TYPE_PALETTE)
285     png_set_palette_to_rgb(context->png_ptr);
286 
287   // handle grayscale case for source or destination
288   const bool want_gray = (context->channels < 3);
289   const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
290   if (is_gray) {  // upconvert gray to 8-bit if needed.
291     if (context->bit_depth < 8) {
292       png_set_expand_gray_1_2_4_to_8(context->png_ptr);
293     }
294   }
295   if (want_gray) {  // output is grayscale
296     if (!is_gray)
297       png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587);  // 601, JPG
298   } else {  // output is rgb(a)
299     if (is_gray)
300       png_set_gray_to_rgb(context->png_ptr);  // Enable gray -> RGB conversion
301   }
302 
303   // Must come last to incorporate all requested transformations.
304   png_read_update_info(context->png_ptr, context->info_ptr);
305   return true;
306 }
307 
CommonFinishDecode(png_bytep data,int row_bytes,DecodeContext * context)308 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
309   CHECK_NOTNULL(data);
310 
311   // we need to re-set the jump point so that we trap the errors
312   // within *this* function (and not CommonInitDecode())
313   if (setjmp(png_jmpbuf(context->png_ptr))) {
314     VLOG(1) << ": DecodePNG error trapped.";
315     CommonFreeDecode(context);
316     return false;
317   }
318   // png_read_row() takes care of offsetting the pointer based on interlacing
319   for (int p = 0; p < context->num_passes; ++p) {
320     png_bytep row = data;
321     for (int h = context->height; h-- != 0; row += row_bytes) {
322       png_read_row(context->png_ptr, row, nullptr);
323     }
324   }
325 
326   // Marks iDAT as valid.
327   png_set_rows(context->png_ptr, context->info_ptr,
328                png_get_rows(context->png_ptr, context->info_ptr));
329   png_read_end(context->png_ptr, context->info_ptr);
330 
331   // Clean up.
332   const bool ok = !context->error_condition;
333   CommonFreeDecode(context);
334 
335   // Synthesize 16 bits from 8 if requested.
336   if (context->need_to_synthesize_16)
337     Convert8to16(absl::bit_cast<uint8*>(data), context->channels, row_bytes,
338                  context->width, context->height, absl::bit_cast<uint16*>(data),
339                  row_bytes);
340   return ok;
341 }
342 
WriteImageToBuffer(const void * image,int width,int height,int row_bytes,int num_channels,int channel_bits,int compression,string * png_string,const std::vector<std::pair<string,string>> * metadata)343 bool WriteImageToBuffer(
344     const void* image, int width, int height, int row_bytes, int num_channels,
345     int channel_bits, int compression, string* png_string,
346     const std::vector<std::pair<string, string> >* metadata) {
347   CHECK_NOTNULL(image);
348   CHECK_NOTNULL(png_string);
349   // Although this case is checked inside png.cc and issues an error message,
350   // that error causes memory corruption.
351   if (width == 0 || height == 0) return false;
352 
353   png_string->resize(0);
354   png_infop info_ptr = nullptr;
355   png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr,
356                                                 ErrorHandler, WarningHandler);
357   if (png_ptr == nullptr) return false;
358   if (setjmp(png_jmpbuf(png_ptr))) {
359     png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr);
360     return false;
361   }
362   info_ptr = png_create_info_struct(png_ptr);
363   if (info_ptr == nullptr) {
364     png_destroy_write_struct(&png_ptr, nullptr);
365     return false;
366   }
367 
368   int color_type = -1;
369   switch (num_channels) {
370     case 1:
371       color_type = PNG_COLOR_TYPE_GRAY;
372       break;
373     case 2:
374       color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
375       break;
376     case 3:
377       color_type = PNG_COLOR_TYPE_RGB;
378       break;
379     case 4:
380       color_type = PNG_COLOR_TYPE_RGB_ALPHA;
381       break;
382     default:
383       png_destroy_write_struct(&png_ptr, &info_ptr);
384       return false;
385   }
386 
387   png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
388   if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
389   png_set_compression_level(png_ptr, compression);
390   png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
391   // There used to be a call to png_set_filter here turning off filtering
392   // entirely, but it produced pessimal compression ratios.  I'm not sure
393   // why it was there.
394   png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
395                PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
396                PNG_FILTER_TYPE_DEFAULT);
397   // If we have metadata write to it.
398   if (metadata && !metadata->empty()) {
399     std::vector<png_text> text;
400     for (const auto& pair : *metadata) {
401       png_text txt;
402       txt.compression = PNG_TEXT_COMPRESSION_NONE;
403       txt.key = check_metadata_string(pair.first);
404       txt.text = check_metadata_string(pair.second);
405       text.push_back(txt);
406     }
407     png_set_text(png_ptr, info_ptr, &text[0], text.size());
408   }
409 
410   png_write_info(png_ptr, info_ptr);
411   if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr);
412 
413   png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
414   for (; height--; row += row_bytes) png_write_row(png_ptr, row);
415   png_write_end(png_ptr, nullptr);
416 
417   png_destroy_write_struct(&png_ptr, &info_ptr);
418   return true;
419 }
420 
421 }  // namespace png
422 }  // namespace tensorflow
423