1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Functions to read and write images in PNG format.
17
18 #include <string.h>
19 #include <sys/types.h>
20 #include <zlib.h>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise
25 // provokes a compile error. We instead let png.h include what is needed.
26
27 #include "absl/base/casts.h"
28 #include "tensorflow/core/lib/png/png_io.h"
29 #include "tensorflow/core/platform/byte_order.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/png.h"
32
33 namespace tensorflow {
34 namespace png {
35
36 ////////////////////////////////////////////////////////////////////////////////
// Decode and encode 8- or 16-bit RGB/grayscale images from/to PNG strings
38 ////////////////////////////////////////////////////////////////////////////////
39
40 namespace {
41
42 #define PTR_INC(type, ptr, del) \
43 (ptr = reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
44 #define CPTR_INC(type, ptr, del) \
45 (ptr = reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + \
46 (del)))
47
48 // Convert from 8 bit components to 16. This works in-place.
Convert8to16(const uint8 * p8,int num_comps,int p8_row_bytes,int width,int height_in,uint16 * p16,int p16_row_bytes)49 static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
50 int width, int height_in, uint16* p16,
51 int p16_row_bytes) {
52 // Force height*row_bytes computations to use 64 bits. Height*width is
53 // enforced to < 29 bits in decode_png_op.cc, but height*row_bytes is
54 // height*width*channels*(8bit?1:2) which is therefore only constrained to <
55 // 33 bits.
56 int64 height = static_cast<int64>(height_in);
57
58 // Adjust pointers to copy backwards
59 width *= num_comps;
60 CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes + (width - 1) * sizeof(*p8));
61 PTR_INC(uint16, p16,
62 (height - 1) * p16_row_bytes + (width - 1) * sizeof(*p16));
63 int bump8 = width * sizeof(*p8) - p8_row_bytes;
64 int bump16 = width * sizeof(*p16) - p16_row_bytes;
65 for (; height-- != 0;
66 CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
67 for (int w = width; w-- != 0; --p8, --p16) {
68 uint32 pix = *p8;
69 pix |= pix << 8;
70 *p16 = static_cast<uint16>(pix);
71 }
72 }
73 }
74
75 #undef PTR_INC
76 #undef CPTR_INC
77
ErrorHandler(png_structp png_ptr,png_const_charp msg)78 void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
79 DecodeContext* const ctx =
80 absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
81 ctx->error_condition = true;
82 // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
83 VLOG(1) << "PNG error: " << msg;
84 longjmp(png_jmpbuf(png_ptr), 1);
85 }
86
WarningHandler(png_structp png_ptr,png_const_charp msg)87 void WarningHandler(png_structp png_ptr, png_const_charp msg) {
88 LOG(WARNING) << "PNG warning: " << msg;
89 }
90
StringReader(png_structp png_ptr,png_bytep data,png_size_t length)91 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
92 DecodeContext* const ctx =
93 absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
94 if (static_cast<png_size_t>(ctx->data_left) < length) {
95 // Don't zero out the data buffer as it has been lazily allocated (copy on
96 // write) and zeroing it out here can produce an OOM. Since the buffer is
97 // only used for reading data from the image, this doesn't result in any
98 // data leak, so it is safe to just leave the buffer be as it is and just
99 // exit with error.
100 png_error(png_ptr, "More bytes requested to read than available");
101 } else {
102 memcpy(data, ctx->data, length);
103 ctx->data += length;
104 ctx->data_left -= length;
105 }
106 }
107
StringWriter(png_structp png_ptr,png_bytep data,png_size_t length)108 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
109 string* const s = absl::bit_cast<string*>(png_get_io_ptr(png_ptr));
110 s->append(absl::bit_cast<const char*>(data), length);
111 }
112
StringWriterFlush(png_structp png_ptr)113 void StringWriterFlush(png_structp png_ptr) {}
114
check_metadata_string(const string & s)115 char* check_metadata_string(const string& s) {
116 const char* const c_string = s.c_str();
117 const size_t length = s.size();
118 if (strlen(c_string) != length) {
119 LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
120 }
121 return const_cast<char*>(c_string);
122 }
123
124 } // namespace
125
126 // We move CommonInitDecode() and CommonFinishDecode()
127 // out of the CommonDecode() template to save code space.
CommonFreeDecode(DecodeContext * context)128 void CommonFreeDecode(DecodeContext* context) {
129 if (context->png_ptr) {
130 png_destroy_read_struct(&context->png_ptr,
131 context->info_ptr ? &context->info_ptr : nullptr,
132 nullptr);
133 context->png_ptr = nullptr;
134 context->info_ptr = nullptr;
135 }
136 }
137
DecodeHeader(StringPiece png_string,int * width,int * height,int * components,int * channel_bit_depth,std::vector<std::pair<string,string>> * metadata)138 bool DecodeHeader(StringPiece png_string, int* width, int* height,
139 int* components, int* channel_bit_depth,
140 std::vector<std::pair<string, string> >* metadata) {
141 DecodeContext context;
142 // Ask for 16 bits even if there may be fewer. This assures that sniffing
143 // the metadata will succeed in all cases.
144 //
145 // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
146 // metadata with setting up the data conversions. These should be separated.
147 constexpr int kDesiredNumChannels = 1;
148 constexpr int kDesiredChannelBits = 16;
149 if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
150 &context)) {
151 return false;
152 }
153 CHECK_NOTNULL(width);
154 *width = static_cast<int>(context.width);
155 CHECK_NOTNULL(height);
156 *height = static_cast<int>(context.height);
157 if (components != nullptr) {
158 switch (context.color_type) {
159 case PNG_COLOR_TYPE_PALETTE:
160 *components =
161 (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS))
162 ? 4
163 : 3;
164 break;
165 case PNG_COLOR_TYPE_GRAY:
166 *components = 1;
167 break;
168 case PNG_COLOR_TYPE_GRAY_ALPHA:
169 *components = 2;
170 break;
171 case PNG_COLOR_TYPE_RGB:
172 *components = 3;
173 break;
174 case PNG_COLOR_TYPE_RGB_ALPHA:
175 *components = 4;
176 break;
177 default:
178 *components = 0;
179 break;
180 }
181 }
182 if (channel_bit_depth != nullptr) {
183 *channel_bit_depth = context.bit_depth;
184 }
185 if (metadata != nullptr) {
186 metadata->clear();
187 png_textp text_ptr = nullptr;
188 int num_text = 0;
189 png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text);
190 for (int i = 0; i < num_text; i++) {
191 const png_text& text = text_ptr[i];
192 metadata->push_back(std::make_pair(text.key, text.text));
193 }
194 }
195 CommonFreeDecode(&context);
196 return true;
197 }
198
CommonInitDecode(StringPiece png_string,int desired_channels,int desired_channel_bits,DecodeContext * context)199 bool CommonInitDecode(StringPiece png_string, int desired_channels,
200 int desired_channel_bits, DecodeContext* context) {
201 CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
202 << "desired_channel_bits = " << desired_channel_bits;
203 CHECK(0 <= desired_channels && desired_channels <= 4)
204 << "desired_channels = " << desired_channels;
205 context->error_condition = false;
206 context->channels = desired_channels;
207 context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
208 ErrorHandler, WarningHandler);
209 if (!context->png_ptr) {
210 VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
211 return false;
212 }
213 if (setjmp(png_jmpbuf(context->png_ptr))) {
214 VLOG(1) << ": DecodePNG error trapped.";
215 CommonFreeDecode(context);
216 return false;
217 }
218 context->info_ptr = png_create_info_struct(context->png_ptr);
219 if (!context->info_ptr || context->error_condition) {
220 VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
221 CommonFreeDecode(context);
222 return false;
223 }
224 context->data = absl::bit_cast<const uint8*>(png_string.data());
225 context->data_left = png_string.size();
226 png_set_read_fn(context->png_ptr, context, StringReader);
227 png_read_info(context->png_ptr, context->info_ptr);
228 png_get_IHDR(context->png_ptr, context->info_ptr, &context->width,
229 &context->height, &context->bit_depth, &context->color_type,
230 nullptr, nullptr, nullptr);
231 if (context->error_condition) {
232 VLOG(1) << ": DecodePNG <- error during header parsing.";
233 CommonFreeDecode(context);
234 return false;
235 }
236 if (context->width <= 0 || context->height <= 0) {
237 VLOG(1) << ": DecodePNG <- invalid dimensions";
238 CommonFreeDecode(context);
239 return false;
240 }
241 const bool has_tRNS =
242 (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0;
243 if (context->channels == 0) { // Autodetect number of channels
244 if (context->color_type == PNG_COLOR_TYPE_PALETTE) {
245 if (has_tRNS) {
246 context->channels = 4; // RGB + A(tRNS)
247 } else {
248 context->channels = 3; // RGB
249 }
250 } else {
251 context->channels = png_get_channels(context->png_ptr, context->info_ptr);
252 }
253 }
254 const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
255 if ((context->channels & 1) == 0) { // We desire alpha
256 if (has_alpha) { // There is alpha
257 } else if (has_tRNS) {
258 png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha
259 } else {
260 png_set_add_alpha(context->png_ptr, (1 << context->bit_depth) - 1,
261 PNG_FILLER_AFTER);
262 }
263 } else { // We don't want alpha
264 if (has_alpha || has_tRNS) { // There is alpha
265 png_set_strip_alpha(context->png_ptr); // Strip alpha
266 }
267 }
268
269 // If we only want 8 bits, but are given 16, strip off the LS 8 bits
270 if (context->bit_depth > 8 && desired_channel_bits <= 8)
271 png_set_strip_16(context->png_ptr);
272
273 context->need_to_synthesize_16 =
274 (context->bit_depth <= 8 && desired_channel_bits == 16);
275
276 png_set_packing(context->png_ptr);
277 context->num_passes = png_set_interlace_handling(context->png_ptr);
278
279 if (desired_channel_bits > 8 && port::kLittleEndian) {
280 png_set_swap(context->png_ptr);
281 }
282
283 // convert palette to rgb(a) if needs be.
284 if (context->color_type == PNG_COLOR_TYPE_PALETTE)
285 png_set_palette_to_rgb(context->png_ptr);
286
287 // handle grayscale case for source or destination
288 const bool want_gray = (context->channels < 3);
289 const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
290 if (is_gray) { // upconvert gray to 8-bit if needed.
291 if (context->bit_depth < 8) {
292 png_set_expand_gray_1_2_4_to_8(context->png_ptr);
293 }
294 }
295 if (want_gray) { // output is grayscale
296 if (!is_gray)
297 png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG
298 } else { // output is rgb(a)
299 if (is_gray)
300 png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion
301 }
302
303 // Must come last to incorporate all requested transformations.
304 png_read_update_info(context->png_ptr, context->info_ptr);
305 return true;
306 }
307
CommonFinishDecode(png_bytep data,int row_bytes,DecodeContext * context)308 bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
309 CHECK_NOTNULL(data);
310
311 // we need to re-set the jump point so that we trap the errors
312 // within *this* function (and not CommonInitDecode())
313 if (setjmp(png_jmpbuf(context->png_ptr))) {
314 VLOG(1) << ": DecodePNG error trapped.";
315 CommonFreeDecode(context);
316 return false;
317 }
318 // png_read_row() takes care of offsetting the pointer based on interlacing
319 for (int p = 0; p < context->num_passes; ++p) {
320 png_bytep row = data;
321 for (int h = context->height; h-- != 0; row += row_bytes) {
322 png_read_row(context->png_ptr, row, nullptr);
323 }
324 }
325
326 // Marks iDAT as valid.
327 png_set_rows(context->png_ptr, context->info_ptr,
328 png_get_rows(context->png_ptr, context->info_ptr));
329 png_read_end(context->png_ptr, context->info_ptr);
330
331 // Clean up.
332 const bool ok = !context->error_condition;
333 CommonFreeDecode(context);
334
335 // Synthesize 16 bits from 8 if requested.
336 if (context->need_to_synthesize_16)
337 Convert8to16(absl::bit_cast<uint8*>(data), context->channels, row_bytes,
338 context->width, context->height, absl::bit_cast<uint16*>(data),
339 row_bytes);
340 return ok;
341 }
342
WriteImageToBuffer(const void * image,int width,int height,int row_bytes,int num_channels,int channel_bits,int compression,string * png_string,const std::vector<std::pair<string,string>> * metadata)343 bool WriteImageToBuffer(
344 const void* image, int width, int height, int row_bytes, int num_channels,
345 int channel_bits, int compression, string* png_string,
346 const std::vector<std::pair<string, string> >* metadata) {
347 CHECK_NOTNULL(image);
348 CHECK_NOTNULL(png_string);
349 // Although this case is checked inside png.cc and issues an error message,
350 // that error causes memory corruption.
351 if (width == 0 || height == 0) return false;
352
353 png_string->resize(0);
354 png_infop info_ptr = nullptr;
355 png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, nullptr,
356 ErrorHandler, WarningHandler);
357 if (png_ptr == nullptr) return false;
358 if (setjmp(png_jmpbuf(png_ptr))) {
359 png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : nullptr);
360 return false;
361 }
362 info_ptr = png_create_info_struct(png_ptr);
363 if (info_ptr == nullptr) {
364 png_destroy_write_struct(&png_ptr, nullptr);
365 return false;
366 }
367
368 int color_type = -1;
369 switch (num_channels) {
370 case 1:
371 color_type = PNG_COLOR_TYPE_GRAY;
372 break;
373 case 2:
374 color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
375 break;
376 case 3:
377 color_type = PNG_COLOR_TYPE_RGB;
378 break;
379 case 4:
380 color_type = PNG_COLOR_TYPE_RGB_ALPHA;
381 break;
382 default:
383 png_destroy_write_struct(&png_ptr, &info_ptr);
384 return false;
385 }
386
387 png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
388 if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
389 png_set_compression_level(png_ptr, compression);
390 png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
391 // There used to be a call to png_set_filter here turning off filtering
392 // entirely, but it produced pessimal compression ratios. I'm not sure
393 // why it was there.
394 png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
395 PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
396 PNG_FILTER_TYPE_DEFAULT);
397 // If we have metadata write to it.
398 if (metadata && !metadata->empty()) {
399 std::vector<png_text> text;
400 for (const auto& pair : *metadata) {
401 png_text txt;
402 txt.compression = PNG_TEXT_COMPRESSION_NONE;
403 txt.key = check_metadata_string(pair.first);
404 txt.text = check_metadata_string(pair.second);
405 text.push_back(txt);
406 }
407 png_set_text(png_ptr, info_ptr, &text[0], text.size());
408 }
409
410 png_write_info(png_ptr, info_ptr);
411 if (channel_bits > 8 && port::kLittleEndian) png_set_swap(png_ptr);
412
413 png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
414 for (; height--; row += row_bytes) png_write_row(png_ptr, row);
415 png_write_end(png_ptr, nullptr);
416
417 png_destroy_write_struct(&png_ptr, &info_ptr);
418 return true;
419 }
420
421 } // namespace png
422 } // namespace tensorflow
423