/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc

#include <cstdint>
#include <memory>

#define EIGEN_USE_THREADS

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gif/gif_io.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"

namespace tensorflow {
namespace {

// Magic bytes (hex) for each image format.
// https://en.wikipedia.org/wiki/List_of_file_signatures
// WARNING: Changing `static const` to `constexpr` requires first checking that
// it works with supported MSVC version.
// https://docs.microsoft.com/en-us/cpp/cpp/constexpr-cpp?redirectedfrom=MSDN&view=vs-2019
static const char kPngMagicBytes[] = "\x89\x50\x4E\x47\x0D\x0A\x1A\x0A";
static const char kGifMagicBytes[] = "\x47\x49\x46\x38";
static const char kBmpMagicBytes[] = "\x42\x4d";
// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three.
static const char kJpegMagicBytes[] = "\xff\xd8\xff";

enum FileFormat {
  kUnknownFormat = 0,
  kPngFormat = 1,
  kJpgFormat = 2,
  kGifFormat = 3,
  kBmpFormat = 4,
};

// Classify the contents of a file based on starting bytes (the magic number).
FileFormat ClassifyFileFormat(StringPiece data) {
  if (absl::StartsWith(data, kJpegMagicBytes)) return kJpgFormat;
  if (absl::StartsWith(data, kPngMagicBytes)) return kPngFormat;
  if (absl::StartsWith(data, kGifMagicBytes)) return kGifFormat;
  if (absl::StartsWith(data, kBmpMagicBytes)) return kBmpFormat;
  return kUnknownFormat;
}

// Decode an image. Supported image formats are JPEG, PNG, GIF and BMP. This is
// a newer version of `DecodeImageOp` that does all image data parsing inside
// the kernel, reducing security vulnerabilities and redundancy.
class DecodeImageV2Op : public OpKernel {
 public:
  explicit DecodeImageV2Op(OpKernelConstruction* context) : OpKernel(context) {
    // Keep track of the op string information because:
    // [1] By the existing API, the PNG, JPEG and GIF ops can decode each
    //     other's formats, and depending on the op type we need to return
    //     either 3-D or 4-D shapes.
    // [2] Different ops have different attributes. e.g. `DecodeImage` op has
    //     an `expand_animations` attribute that other ops don't.
    //     `DecodeAndCropJpeg` also has additional attributes.
    op_type_ = type_string();

    // Validate op type.
    OP_REQUIRES(context,
                op_type_ == "DecodeJpeg" || op_type_ == "DecodeAndCropJpeg" ||
                    op_type_ == "DecodePng" || op_type_ == "DecodeGif" ||
                    op_type_ == "DecodeBmp" || op_type_ == "DecodeImage",
                errors::InvalidArgument("Bad op type ", op_type_));

    // Get attributes from `DecodeJpeg` and `DecodeAndCropJpeg` op
    // invocations. For the `DecodeImage` op, set the JPEG decoding settings
    // to the TF defaults.
    if (op_type_ == "DecodeJpeg" || op_type_ == "DecodeAndCropJpeg") {
      OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio));
      OP_REQUIRES(context,
                  flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 ||
                      flags_.ratio == 8,
                  errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ",
                                          flags_.ratio));
      OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling",
                                               &flags_.fancy_upscaling));
      OP_REQUIRES_OK(context,
                     context->GetAttr("try_recover_truncated",
                                      &flags_.try_recover_truncated_jpeg));
      OP_REQUIRES_OK(context,
                     context->GetAttr("acceptable_fraction",
                                      &flags_.min_acceptable_fraction));
      string dct_method;
      OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method));
      OP_REQUIRES(
          context,
          (dct_method.empty() || dct_method == "INTEGER_FAST" ||
           dct_method == "INTEGER_ACCURATE"),
          errors::InvalidArgument("dct_method must be one of "
                                  "{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}"));
      // The TensorFlow-chosen default for JPEG decoding is IFAST, sacrificing
      // image quality for speed.
      if (dct_method.empty() || dct_method == "INTEGER_FAST") {
        flags_.dct_method = JDCT_IFAST;
      } else if (dct_method == "INTEGER_ACCURATE") {
        flags_.dct_method = JDCT_ISLOW;
      }
    } else {
      flags_ = jpeg::UncompressFlags();
      flags_.dct_method = JDCT_IFAST;
    }

    // Get `dtype` attribute from `DecodePng` or `DecodeImage` op invocations.
    if (op_type_ == "DecodePng" || op_type_ == "DecodeImage") {
      OP_REQUIRES_OK(context, context->GetAttr("dtype", &data_type_));
      if (op_type_ == "DecodePng") {
        OP_REQUIRES(
            context,
            data_type_ == DataType::DT_UINT8 ||
                data_type_ == DataType::DT_UINT16,
            errors::InvalidArgument(
                "`dtype` for `DecodePng` must be uint8 or uint16 but got: ",
                data_type_));
      } else {
        OP_REQUIRES(context,
                    data_type_ == DataType::DT_UINT8 ||
                        data_type_ == DataType::DT_UINT16 ||
                        data_type_ == DataType::DT_FLOAT,
                    errors::InvalidArgument("`dtype` for `DecodeImage` must be "
                                            "uint8, uint16 or float but got: ",
                                            data_type_));
        OP_REQUIRES_OK(context, context->GetAttr("expand_animations",
                                                 &expand_animations_));
      }
    }

    // Get `channels` attribute for all ops except `DecodeGif` op.
    // `DecodeGif` doesn't have `channels` attribute but it supports 3
    // channels by default.
    if (op_type_ != "DecodeGif") {
      OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
      OP_REQUIRES(
          context,
          channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
          errors::InvalidArgument("`channels` must be 0, 1, 3 or 4 but got ",
                                  channels_));
    } else {
      channels_ = 3;
    }
  }

  // Helper for decoding BMP.
  inline int32 ByteSwapInt32ForBigEndian(int32 x) {
    if (!port::kLittleEndian) {
      return BYTE_SWAP_32(x);
    } else {
      return x;
    }
  }

  // Helper for decoding BMP.
  inline int16 ByteSwapInt16ForBigEndian(int16 x) {
    if (!port::kLittleEndian) {
      return BYTE_SWAP_16(x);
    } else {
      return x;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& contents = context->input(0);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(contents.shape()),
        errors::InvalidArgument("`contents` must be scalar but got shape ",
                                contents.shape().DebugString()));
    const StringPiece input = contents.scalar<tstring>()();
    OP_REQUIRES(context, !input.empty(),
                errors::InvalidArgument("Input is empty."));
    OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
                errors::InvalidArgument(
                    "Input contents are too large for int: ", input.size()));

    // Parse magic bytes to determine file format.
    switch (ClassifyFileFormat(input)) {
      case kJpgFormat:
        DecodeJpegV2(context, input);
        break;
      case kPngFormat:
        DecodePngV2(context, input);
        break;
      case kGifFormat:
        DecodeGifV2(context, input);
        break;
      case kBmpFormat:
        DecodeBmpV2(context, input);
        break;
      case kUnknownFormat:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Unknown image file format. One of "
                                            "JPEG, PNG, GIF, BMP required."));
        break;
    }
  }

  void DecodeJpegV2(OpKernelContext* context, StringPiece input) {
    OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3,
                errors::InvalidArgument("JPEG does not support 4 channels"));

    // Use a local copy of the flags to avoid a race condition, as the class
    // member is shared among different invocations.
    jpeg::UncompressFlags flags = flags_;
    flags.components = channels_;

    if (op_type_ == "DecodeAndCropJpeg") {
      flags.crop = true;
      // Update flags to include the crop window.
      const Tensor& crop_window = context->input(1);
      OP_REQUIRES(context, crop_window.dims() == 1,
                  errors::InvalidArgument("crop_window must be 1-D, got shape ",
                                          crop_window.shape().DebugString()));
      OP_REQUIRES(context, crop_window.dim_size(0) == 4,
                  errors::InvalidArgument("crop_window must have 4 elements, "
                                          "got shape ",
                                          crop_window.shape().DebugString()));
      auto crop_window_vec = crop_window.vec<int32>();
      flags.crop_y = crop_window_vec(0);
      flags.crop_x = crop_window_vec(1);
      flags.crop_height = crop_window_vec(2);
      flags.crop_width = crop_window_vec(3);
    } else if (op_type_ == "DecodeBmp") {
      // TODO(b/171060723): Rejecting only `DecodeBmp` here is inconsistent:
      // currently the `decode_(jpeg|png|gif)` ops can decode any one of
      // jpeg, png or gif but not bmp, while `decode_bmp` cannot decode
      // anything but bmp formats. This behavior needs to be revisited. For
      // more details, please refer to the bug.
      OP_REQUIRES(context, false,
                  errors::InvalidArgument(
                      "Trying to decode JPEG format using DecodeBmp op. Use "
                      "`decode_jpeg` or `decode_image` instead."));
    }

    // Output tensor and the image buffer size.
    Tensor* output = nullptr;
    int buffer_size = 0;

    // Decode JPEG. Decode directly into the output buffer if the data type is
    // uint8 (to save an extra copy). Otherwise, allocate a new uint8 buffer
    // of `buffer_size` bytes. `jpeg::Uncompress` supports uint8 only.
    uint8* buffer = jpeg::Uncompress(
        input.data(), input.size(), flags, nullptr /* nwarn */,
        [&](int width, int height, int channels) -> uint8* {
          buffer_size = height * width * channels;
          Status status;
          // By the existing API, we support decoding JPEG with the `DecodeGif`
          // op. We need to make sure to return 4-D shapes when using
          // `DecodeGif`.
          if (op_type_ == "DecodeGif") {
            status = context->allocate_output(
                0, TensorShape({1, height, width, channels}), &output);
          } else {
            status = context->allocate_output(
                0, TensorShape({height, width, channels}), &output);
          }
          if (!status.ok()) {
            VLOG(1) << status;
            context->SetStatus(status);
            return nullptr;
          }

          if (data_type_ == DataType::DT_UINT8) {
            return output->flat<uint8>().data();
          } else {
            return new uint8[buffer_size];
          }
        });

    OP_REQUIRES(
        context, buffer,
        errors::InvalidArgument(
            "jpeg::Uncompress failed. Invalid JPEG data or crop window."));

    // When the desired data type is uint8, the output buffer was already
    // allocated during the `jpeg::Uncompress` call above; return.
    if (data_type_ == DataType::DT_UINT8) {
      return;
    }
    // Make sure we don't forget to deallocate `buffer`.
    std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);

    // Convert the uint8 image data to the desired data type.
    // Use Eigen threadpooling to speed up the copy operation.
    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
    TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
    if (data_type_ == DataType::DT_UINT16) {
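      // (65535 + 1) / (255 + 1) == 256, so each uint8 value v is mapped to
      // v * 256 (255 becomes 65280 rather than 65535). The GIF and BMP paths
      // below use the same conversion.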
      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
                           (std::numeric_limits<uint8>::max() + 1));
      // Fill output tensor with desired dtype.
      output->flat<uint16>().device(device) =
          buffer_view.cast<uint16>() * scale;
    } else if (data_type_ == DataType::DT_FLOAT) {
      float scale = 1. / std::numeric_limits<uint8>::max();
      // Fill output tensor with desired dtype.
      output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
    }
  }

  void DecodePngV2(OpKernelContext* context, StringPiece input) {
    int channel_bits = (data_type_ == DataType::DT_UINT8) ? 8 : 16;
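    // For DT_UINT16 and DT_FLOAT outputs the PNG is decoded at 16 bits per
    // channel; the float case is rescaled by 1/65535 further below.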
    png::DecodeContext decode;
    OP_REQUIRES(
        context, png::CommonInitDecode(input, channels_, channel_bits, &decode),
        errors::InvalidArgument("Invalid PNG. Failed to initialize decoder."));

    // Verify that width and height are not too large:
    // - verify width and height don't overflow int.
    // - width can later be multiplied by channels_ and sizeof(uint16), so
    //   verify single dimension is not too large.
    // - verify when width and height are multiplied together, there are a few
    //   bits to spare as well.
    const int width = static_cast<int>(decode.width);
    const int height = static_cast<int>(decode.height);
    const int64 total_size =
        static_cast<int64>(width) * static_cast<int64>(height);
    if (width != static_cast<int64>(decode.width) || width <= 0 ||
        width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
        height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
      png::CommonFreeDecode(&decode);
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("PNG size too large for int: ",
                                          decode.width, " by ", decode.height));
    }

    Tensor* output = nullptr;
    Status status;
    // By the existing API, we support decoding PNG with `DecodeGif` op.
    // We need to make sure to return 4-D shapes when using `DecodeGif`.
    if (op_type_ == "DecodeGif") {
      status = context->allocate_output(
          0, TensorShape({1, height, width, decode.channels}), &output);
    } else {
      status = context->allocate_output(
          0, TensorShape({height, width, decode.channels}), &output);
    }

    if (op_type_ == "DecodeBmp") {
      // TODO(b/171060723): Rejecting only `DecodeBmp` here is inconsistent:
      // currently the `decode_(jpeg|png|gif)` ops can decode any one of
      // jpeg, png or gif but not bmp, while `decode_bmp` cannot decode
      // anything but bmp formats. This behavior needs to be revisited. For
      // more details, please refer to the bug.
      OP_REQUIRES(context, false,
                  errors::InvalidArgument(
                      "Trying to decode PNG format using DecodeBmp op. Use "
                      "`decode_png` or `decode_image` instead."));
    } else if (op_type_ == "DecodeAndCropJpeg") {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument(
                      "DecodeAndCropJpeg operation can run on JPEG only, but "
                      "detected PNG."));
    }

    if (!status.ok()) png::CommonFreeDecode(&decode);
    OP_REQUIRES_OK(context, status);

    if (data_type_ == DataType::DT_UINT8) {
      OP_REQUIRES(
          context,
          png::CommonFinishDecode(
              reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
              decode.channels * width * sizeof(uint8), &decode),
          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
    } else if (data_type_ == DataType::DT_UINT16) {
      OP_REQUIRES(
          context,
          png::CommonFinishDecode(
              reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
              decode.channels * width * sizeof(uint16), &decode),
          errors::InvalidArgument("Invalid PNG data, size ", input.size()));
    } else if (data_type_ == DataType::DT_FLOAT) {
      // `png::CommonFinishDecode` does not support `float`. First allocate a
      // uint16 buffer for the image and decode in uint16 (lossless). Wrap the
      // buffer in `unique_ptr` so that we don't forget to delete the buffer.
      std::unique_ptr<uint16[]> buffer(
          new uint16[height * width * decode.channels]);
      OP_REQUIRES(
          context,
          png::CommonFinishDecode(reinterpret_cast<png_bytep>(buffer.get()),
                                  decode.channels * width * sizeof(uint16),
                                  &decode),
          errors::InvalidArgument("Invalid PNG data, size ", input.size()));

      // Convert the uint16 image data to the desired data type.
      // Use Eigen threadpooling to speed up the copy operation.
      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
      TTypes<uint16, 3>::UnalignedConstTensor buf(buffer.get(), height, width,
                                                  decode.channels);
      float scale = 1. / std::numeric_limits<uint16>::max();
      // Fill output tensor with desired dtype.
      output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
    }
  }

  void DecodeGifV2(OpKernelContext* context, StringPiece input) {
    // GIF has 3 channels.
    OP_REQUIRES(context, channels_ == 0 || channels_ == 3,
                errors::InvalidArgument("channels must be 0 or 3 for GIF, got ",
                                        channels_));

    if (op_type_ == "DecodeBmp") {
      // TODO(b/171060723): Rejecting only `DecodeBmp` here is inconsistent:
      // currently the `decode_(jpeg|png|gif)` ops can decode any one of
      // jpeg, png or gif but not bmp, while `decode_bmp` cannot decode
      // anything but bmp formats. This behavior needs to be revisited. For
      // more details, please refer to the bug.
      OP_REQUIRES(context, false,
                  errors::InvalidArgument(
                      "Trying to decode GIF format using DecodeBmp op. Use "
                      "`decode_gif` or `decode_image` instead."));
    } else if (op_type_ == "DecodeAndCropJpeg") {
      OP_REQUIRES(context, false,
                  errors::InvalidArgument(
                      "DecodeAndCropJpeg operation can run on JPEG only, but "
                      "detected GIF."));
    }

    // Decode GIF into a uint8 buffer. If dtype is uint8, decode directly into
    // the output tensor; otherwise decode into a temporary buffer and convert
    // afterwards. `gif::Decode` supports uint8 only.
    Tensor* output = nullptr;
    int buffer_size = 0;
    string error_string;
    uint8* buffer = gif::Decode(
        input.data(), input.size(),
        [&](int num_frames, int width, int height, int channels) -> uint8* {
          buffer_size = num_frames * height * width * channels;

          Status status;
          // By the existing API, we support decoding GIF with `decode_jpeg` or
          // with `decode_png` if the GIF is a single-frame GIF (non-animated).
          // We need to make sure to return 3-D shapes in this case.
          if (op_type_ == "DecodePng" || op_type_ == "DecodeJpeg") {
            if (num_frames == 1) {
              status = context->allocate_output(
                  0, TensorShape({height, width, channels}), &output);
            } else {
              status = errors::InvalidArgument(
                  "Got ", num_frames, " frames, but animated gifs ",
                  "can only be decoded by tf.io.decode_gif or ",
                  "tf.io.decode_image");
            }
          } else if (op_type_ == "DecodeGif" ||
                     (op_type_ == "DecodeImage" && expand_animations_)) {
            status = context->allocate_output(
                0, TensorShape({num_frames, height, width, channels}), &output);
          } else if (op_type_ == "DecodeImage" && !expand_animations_) {
            status = context->allocate_output(
                0, TensorShape({height, width, channels}), &output);
          } else {
            status = errors::InvalidArgument("Bad op type ", op_type_);
          }
          if (!status.ok()) {
            VLOG(1) << status;
            context->SetStatus(status);
            return nullptr;
          }

          if (data_type_ == DataType::DT_UINT8) {
            return output->flat<uint8>().data();
          } else {
            return new uint8[buffer_size];
          }
        },
        &error_string, expand_animations_);

    OP_REQUIRES(context, buffer,
                errors::InvalidArgument("Invalid GIF data (size ", input.size(),
                                        "), ", error_string));

    // When the desired data type is uint8, the output buffer was already
    // allocated during the `gif::Decode` call above; return.
    if (data_type_ == DataType::DT_UINT8) {
      return;
    }
    // Make sure we don't forget to deallocate `buffer`.
    std::unique_ptr<uint8[]> buffer_unique_ptr(buffer);

    // Convert the raw uint8 buffer to the desired dtype.
    // Use Eigen threadpooling to speed up the copy operation.
    TTypes<uint8>::UnalignedConstFlat buffer_view(buffer, buffer_size);
    const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
    if (data_type_ == DataType::DT_UINT16) {
      uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
                           (std::numeric_limits<uint8>::max() + 1));
      // Fill output tensor with desired dtype.
      output->flat<uint16>().device(device) =
          buffer_view.cast<uint16>() * scale;
    } else if (data_type_ == DataType::DT_FLOAT) {
      float scale = 1. / std::numeric_limits<uint8>::max();
      // Fill output tensor with desired dtype.
      output->flat<float>().device(device) = buffer_view.cast<float>() * scale;
    }
  }

  void DecodeBmpV2(OpKernelContext* context, StringPiece input) {
    OP_REQUIRES(
        context, channels_ != 1,
        errors::InvalidArgument(
            "`channels` must be 0, 3 or 4 for BMP, but got ", channels_));

    if (op_type_ != "DecodeBmp" && op_type_ != "DecodeImage") {
      if (op_type_ == "DecodeAndCropJpeg") {
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "DecodeAndCropJpeg operation can run on JPEG only, but "
                        "detected BMP."));
      } else {
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "Trying to decode BMP format using a wrong op. Use "
                        "`decode_bmp` or `decode_image` instead. Op used: ",
                        op_type_));
      }
    }

    OP_REQUIRES(context, (32 <= input.size()),
                errors::InvalidArgument("Incomplete bmp content, requires at "
                                        "least 32 bytes to find the header "
                                        "size, width, height, and bpp, got ",
                                        input.size(), " bytes"));

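    // BMP layout (little-endian): the 14-byte BITMAPFILEHEADER ends with the
    // pixel-array offset at byte 10 (read below as `header_size` and also used
    // as the offset of the pixel data), followed by a BITMAPINFOHEADER holding
    // the width at byte 18, the signed height at byte 22 (negative means
    // top-down row order) and the bits-per-pixel at byte 28.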
    const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
    int32 header_size_ = internal::SubtleMustCopy(
        *(reinterpret_cast<const int32*>(img_bytes + 10)));
    const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
    int32 width_ = internal::SubtleMustCopy(
        *(reinterpret_cast<const int32*>(img_bytes + 18)));
    const int32 width = ByteSwapInt32ForBigEndian(width_);
    int32 height_ = internal::SubtleMustCopy(
        *(reinterpret_cast<const int32*>(img_bytes + 22)));
    const int32 height = ByteSwapInt32ForBigEndian(height_);
    int16 bpp_ = internal::SubtleMustCopy(
        *(reinterpret_cast<const int16*>(img_bytes + 28)));
    const int16 bpp = ByteSwapInt16ForBigEndian(bpp_);

    // `channels_` is the desired number of channels. `img_channels` is the
    // number of channels inherent in the image.
    int img_channels = bpp / 8;
    OP_REQUIRES(
        context, (img_channels == 1 || img_channels == 3 || img_channels == 4),
        errors::InvalidArgument(
            "Number of channels inherent in the image must be 1, 3 or 4, was ",
            img_channels));
    const int requested_channels = channels_ ? channels_ : img_channels;

    OP_REQUIRES(context, width > 0,
                errors::InvalidArgument("Width must be positive"));
    OP_REQUIRES(context, height != 0,
                errors::InvalidArgument("Height must be nonzero"));
    OP_REQUIRES(context, header_size >= 0,
                errors::InvalidArgument("header size must be nonnegative"));

    // The real requirement is < 2^31 minus some headers and channel data,
    // so rounding down to something that's still ridiculously big.
    OP_REQUIRES(
        context,
        (static_cast<int64>(width) * std::abs(static_cast<int64>(height))) <
            static_cast<int64>(std::numeric_limits<int32_t>::max() / 8),
        errors::InvalidArgument(
            "Total possible pixel bytes must be less than 2^30"));

    const int32 abs_height = abs(height);

    // There may be padding bytes when the row width in bytes is not a
    // multiple of 4.
    const int row_size = (img_channels * width + 3) / 4 * 4;
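    // e.g. a 3-channel, 5-pixel-wide row needs 15 bytes and is padded to 16.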

    // Make sure the size of input data matches up with the total size of
    // headers plus height * row_size.
    int size_diff = input.size() - header_size - (row_size * abs_height);
    OP_REQUIRES(
        context, size_diff == 0,
        errors::InvalidArgument(
            "Input size should match (header_size + row_size * abs_height) but "
            "they differ by ",
            size_diff));

    const int64 last_pixel_offset = static_cast<int64>(header_size) +
                                    (abs_height - 1) * row_size +
                                    (width - 1) * img_channels;

    // [expected file size] = [last pixel offset] + [last pixel size=channels]
    const int64 expected_file_size = last_pixel_offset + img_channels;

    OP_REQUIRES(
        context, (expected_file_size <= input.size()),
        errors::InvalidArgument("Incomplete bmp content, requires at least ",
                                expected_file_size, " bytes, got ",
                                input.size(), " bytes"));

    // If height is negative, the data layout is top-down;
    // otherwise, it is bottom-up.
    bool top_down = (height < 0);

    // Decode image, allocating tensor once the image size is known.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0, TensorShape({abs_height, width, requested_channels}), &output));

    const uint8* bmp_pixels = &img_bytes[header_size];

    if (data_type_ == DataType::DT_UINT8) {
      DecodeBMP(bmp_pixels, row_size, output->flat<uint8>().data(), width,
                abs_height, requested_channels, img_channels, top_down);
    } else {
      std::unique_ptr<uint8[]> buffer(
          new uint8[abs_height * width * requested_channels]);
      DecodeBMP(bmp_pixels, row_size, buffer.get(), width, abs_height,
                requested_channels, img_channels, top_down);
      TTypes<uint8, 3>::UnalignedConstTensor buf(buffer.get(), abs_height,
                                                 width, requested_channels);
      // Convert the raw uint8 buffer to the desired dtype.
      // Use Eigen threadpooling to speed up the copy operation.
      const auto& device = context->eigen_device<Eigen::ThreadPoolDevice>();
      if (data_type_ == DataType::DT_UINT16) {
        uint16 scale = floor((std::numeric_limits<uint16>::max() + 1) /
                             (std::numeric_limits<uint8>::max() + 1));
        // Fill output tensor with desired dtype.
        output->tensor<uint16, 3>().device(device) =
            buf.cast<uint16>() * scale;
      } else if (data_type_ == DataType::DT_FLOAT) {
        float scale = 1. / std::numeric_limits<uint8>::max();
        // Fill output tensor with desired dtype.
        output->tensor<float, 3>().device(device) = buf.cast<float>() * scale;
      }
    }
  }

 private:
  void DecodeBMP(const uint8* input, const int row_size, uint8* const output,
                 const int width, const int height, const int output_channels,
                 const int input_channels, bool top_down);

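  // `channels_ == 0` means "keep the number of channels found in the image"
  // (see `requested_channels` in `DecodeBmpV2` and `flags.components` in
  // `DecodeJpegV2`).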
  int channels_ = 0;
  DataType data_type_ = DataType::DT_UINT8;
  bool expand_animations_ = true;
  jpeg::UncompressFlags flags_;
  string op_type_;
};

REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageV2Op);
REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageV2Op);
REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageV2Op);
REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
                        DecodeImageV2Op);
REGISTER_KERNEL_BUILDER(Name("DecodeImage").Device(DEVICE_CPU),
                        DecodeImageV2Op);
REGISTER_KERNEL_BUILDER(Name("DecodeBmp").Device(DEVICE_CPU), DecodeImageV2Op);

void DecodeImageV2Op::DecodeBMP(const uint8* input, const int row_size,
                                uint8* const output, const int width,
                                const int height, const int output_channels,
                                const int input_channels, bool top_down) {
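  // `input` points at the first byte of the BMP pixel array, `row_size` is the
  // 4-byte-padded stride of one stored row, and `height` is the absolute
  // (non-negative) height passed in by the caller.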
  for (int i = 0; i < height; i++) {
    int src_pos;
    int dst_pos;

    for (int j = 0; j < width; j++) {
      if (!top_down) {
        src_pos = ((height - 1 - i) * row_size) + j * input_channels;
      } else {
        src_pos = i * row_size + j * input_channels;
      }

      dst_pos = (i * width + j) * output_channels;

      switch (input_channels) {
        case 1:
          output[dst_pos] = input[src_pos];
          // Set the 2nd and 3rd channels if the user requested 3 or 4
          // channels. Repeat the 1st channel's value.
          if (output_channels == 3 || output_channels == 4) {
            output[dst_pos + 1] = input[src_pos];
            output[dst_pos + 2] = input[src_pos];
          }
          // Set the 4th channel (alpha) to the maximum value if the user
          // requested 4 channels.
          if (output_channels == 4) {
            output[dst_pos + 3] = UINT8_MAX;
          }
          break;
        case 3:
          // BGR -> RGB
          output[dst_pos] = input[src_pos + 2];
          output[dst_pos + 1] = input[src_pos + 1];
          output[dst_pos + 2] = input[src_pos];
          // Set the 4th channel (alpha) to the maximum value if the user
          // requested 4 channels and the input image has only 3 channels.
          if (output_channels == 4) {
            output[dst_pos + 3] = UINT8_MAX;
          }
          break;
        case 4:
          // BGRA -> RGBA
          output[dst_pos] = input[src_pos + 2];
          output[dst_pos + 1] = input[src_pos + 1];
          output[dst_pos + 2] = input[src_pos];
          // Copy the 4th channel (alpha) only if the user requested 4
          // channels; otherwise the user requested 3 channels and alpha is
          // dropped.
          if (output_channels == 4) {
            output[dst_pos + 3] = input[src_pos + 3];
          }
          break;
        default:
          LOG(FATAL) << "Unexpected number of channels: " << input_channels;
          break;
      }
    }
  }
}

}  // namespace
}  // namespace tensorflow