/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace acceleration {
namespace decode_jpeg_kernel {

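// Parses the flexbuffer custom options attached to the DECODE_JPEG node. The
// options are expected to be a map with four integer entries: "height",
// "width", "num_images" and "channels". If any entry is missing or is not an
// integer, an error is logged and nullptr is returned, which later makes the
// op_data check in Prepare() fail.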
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  if (!buffer) {
    return nullptr;
  }
#define RET_ENSURE(context, condition)                                  \
  do {                                                                  \
    if (!(condition)) {                                                 \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s was not true.", __FILE__, \
                         __LINE__, #condition);                         \
      return nullptr;                                                   \
    }                                                                   \
  } while (0)
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  RET_ENSURE(context, m["height"].IsInt());
  RET_ENSURE(context, m["width"].IsInt());
  RET_ENSURE(context, m["num_images"].IsInt());
  RET_ENSURE(context, m["channels"].IsInt());
  OpData* op_data = new OpData();
  op_data->height = m["height"].AsInt32();
  op_data->width = m["width"].AsInt32();
  op_data->num_images = m["num_images"].AsInt32();
  op_data->channels = m["channels"].AsInt32();
  return op_data;
#undef RET_ENSURE
}
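
// A minimal sketch of how custom options matching Init() above could be
// serialized with flexbuffers when building a model. The concrete values are
// illustrative only and are not taken from this file:
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Int("height", 224);    // decoded height in pixels
//     fbb.Int("width", 224);     // decoded width in pixels
//     fbb.Int("num_images", 1);  // number of JPEG strings in the input batch
//     fbb.Int("channels", 3);    // 3 (RGB) or 4 (RGBA with opaque alpha)
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() holds the bytes passed to Init() as `buffer`.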

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE(context, op_data);
  TF_LITE_ENSURE(context, op_data->height > 0);
  TF_LITE_ENSURE(context, op_data->width > 0);
  TF_LITE_ENSURE(context, op_data->num_images > 0);
  // TODO(b/172544567): Support grayscale images.
  TF_LITE_ENSURE(context, op_data->channels == 3 || op_data->channels == 4);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  const TfLiteTensor* input_buffer;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, /*index=*/0, &input_buffer));

  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, /*index=*/0, &output_tensor));

  TF_LITE_ENSURE_TYPES_EQ(context, input_buffer->type, kTfLiteString);
  TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteUInt8);

  TF_LITE_ENSURE_EQ(context, NumDimensions(input_buffer), 1);
  TF_LITE_ENSURE_EQ(context, input_buffer->dims->data[0], op_data->num_images);

  // Resize output.
  // Output shape is determined as {num_images, height, width, channels}.
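  // For example, num_images=2, height=224, width=224, channels=3 resizes the
  // output to a [2, 224, 224, 3] tensor of kTfLiteUInt8.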
  TfLiteIntArray* new_dims = TfLiteIntArrayCreate(4);
  new_dims->data[0] = op_data->num_images;
  new_dims->data[1] = op_data->height;
  new_dims->data[2] = op_data->width;
  new_dims->data[3] = op_data->channels;
  output_tensor->type = kTfLiteUInt8;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output_tensor, new_dims));
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Decodes a batch of JPEG images.

  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input_buffer;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, /*index=*/0, &input_buffer));
  TF_LITE_ENSURE(context, input_buffer);
  TF_LITE_ENSURE(context, input_buffer->data.raw);
  const int channels = op_data->channels;
  // TODO(b/172544567): Support grayscale images.
  const int decode_channels = 3;
  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, /*index=*/0, &output_tensor));
  // kTfLiteUInt8 corresponds to unsigned char as shown in
  // "tensorflow/lite/portable_type_to_tflitetype.h".
  unsigned char* output_arr = GetTensorData<unsigned char>(output_tensor);
  Status decoder_status;
  std::unique_ptr<LibjpegDecoder> decoder =
      LibjpegDecoder::Create(decoder_status);
  if (decoder_status.code != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context, "%s", decoder_status.error_message.c_str());
    return kTfLiteError;
  }

  const int kDecodedImageSize =
      op_data->width * op_data->height * decode_channels;
  const int kOutputImageSize = op_data->width * op_data->height * channels;

  int output_array_offset = 0;
  for (int img = 0; img < op_data->num_images; ++img) {
    tflite::StringRef inputref =
        tflite::GetString(input_buffer, /*string_index=*/img);
    unsigned char* decoded = output_arr + output_array_offset;

    Status decode_status = decoder->DecodeImage(
        inputref, {op_data->height, op_data->width, decode_channels}, decoded,
        kDecodedImageSize);

    if (channels == 4) {
      // Reorganize the decoded buffer from 3 channels to 4 channels.
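      // The expansion runs in place from the last pixel to the first so that
      // RGB bytes are moved to their RGBA positions before being overwritten.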
      size_t height = op_data->height;
      size_t src_offset = kDecodedImageSize;
      size_t dst_offset = kOutputImageSize;
      while (height--) {
        size_t width = op_data->width;
        while (width--) {
          src_offset -= decode_channels;
          dst_offset -= channels;
          std::copy_n(decoded + src_offset, decode_channels,
                      decoded + dst_offset);
          // Set the alpha channel of the current pixel to 255 (fully opaque)
          // when the requested output has 4 channels. This is a workaround
          // that lets the JPEG decoder feed models with 4-channel inputs.
          decoded[dst_offset + 3] = 255;
        }
      }
    }

    output_array_offset += kOutputImageSize;

    if (decode_status.code != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context, "%s", decode_status.error_message.c_str());
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

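// Returns the registration for the DECODE_JPEG custom op. A typical way to
// make the op available to an interpreter is to register it with an op
// resolver; the op name below is illustrative only and must match the custom
// op name stored in the model:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("DECODE_JPEG",
//                      tflite::acceleration::decode_jpeg_kernel::
//                          Register_DECODE_JPEG());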
TfLiteRegistration* Register_DECODE_JPEG() {
  static TfLiteRegistration r = {
      decode_jpeg_kernel::Init, decode_jpeg_kernel::Free,
      decode_jpeg_kernel::Prepare, decode_jpeg_kernel::Eval};
  return &r;
}

}  // namespace decode_jpeg_kernel
}  // namespace acceleration
}  // namespace tflite