1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <vector>
17
18 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
19 #include "tensorflow/lite/c/c_api_types.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h"
22 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/string_type.h"
26 #include "tensorflow/lite/string_util.h"
27
28 namespace tflite {
29 namespace acceleration {
30 namespace decode_jpeg_kernel {
31
Init(TfLiteContext * context,const char * buffer,size_t length)32 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
33 if (!buffer) {
34 return nullptr;
35 }
36 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
37 const flexbuffers::Map m = flexbuffers::GetRoot(buffer_t, length).AsMap();
38 // TODO(b/172544567): Add error handling for incorrect/missing attributes.
39 OpData* op_data = new OpData();
40 op_data->height = m["height"].AsInt32();
41 op_data->width = m["width"].AsInt32();
42 op_data->num_images = m["num_images"].AsInt32();
43 return op_data;
44 }
45
Free(TfLiteContext * context,void * buffer)46 void Free(TfLiteContext* context, void* buffer) {
47 delete reinterpret_cast<OpData*>(buffer);
48 }
49
Prepare(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
51 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
52 TF_LITE_ENSURE(context, op_data);
53 TF_LITE_ENSURE(context, op_data->height > 0);
54 TF_LITE_ENSURE(context, op_data->width > 0);
55 TF_LITE_ENSURE(context, op_data->num_images > 0);
56
57 TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
58 TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
59
60 const TfLiteTensor* input_buffer;
61 TF_LITE_ENSURE_OK(context,
62 GetInputSafe(context, node, /*index=*/0, &input_buffer));
63
64 TfLiteTensor* output_tensor;
65 TF_LITE_ENSURE_OK(context,
66 GetOutputSafe(context, node, /*index=*/0, &output_tensor));
67
68 TF_LITE_ENSURE_TYPES_EQ(context, input_buffer->type, kTfLiteString);
69 TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteUInt8);
70
71 TF_LITE_ENSURE_EQ(context, NumDimensions(input_buffer), 1);
72 TF_LITE_ENSURE_EQ(context, input_buffer->dims->data[0], op_data->num_images);
73
74 // Resize output.
75 // Output shape is determined as {num_images, height, width, channels}.
76 TfLiteIntArray* new_dims = TfLiteIntArrayCreate(4);
77 new_dims->data[0] = op_data->num_images;
78 new_dims->data[1] = op_data->height;
79 new_dims->data[2] = op_data->width;
80 // TODO(b/172544567): Support grayscale images.
81 new_dims->data[3] = 3; // Channels.
82 output_tensor->type = kTfLiteUInt8;
83 TF_LITE_ENSURE_OK(context,
84 context->ResizeTensor(context, output_tensor, new_dims));
85 return kTfLiteOk;
86 }
87
Eval(TfLiteContext * context,TfLiteNode * node)88 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
89 // Decodes a batch of JPEG images.
90
91 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
92
93 const TfLiteTensor* input_buffer;
94 TF_LITE_ENSURE_OK(context,
95 GetInputSafe(context, node, /*index=*/0, &input_buffer));
96 TF_LITE_ENSURE(context, input_buffer);
97 TF_LITE_ENSURE(context, input_buffer->data.raw);
98 TfLiteTensor* output_tensor;
99 TF_LITE_ENSURE_OK(context,
100 GetOutputSafe(context, node, /*index=*/0, &output_tensor));
101 // kTfliteUInt8 corresponds to unsigned char as shown in
102 // "tensorflow/lite/portable_type_to_tflitetype.h".
103 unsigned char* output_arr = GetTensorData<unsigned char>(output_tensor);
104 Status decoder_status;
105 std::unique_ptr<LibjpegDecoder> decoder =
106 LibjpegDecoder::Create(decoder_status);
107 if (decoder_status.code != kTfLiteOk) {
108 TF_LITE_KERNEL_LOG(context, decoder_status.error_message.c_str());
109 return kTfLiteError;
110 }
111
112 const int kImageSize = op_data->width * op_data->height * 3;
113 int output_array_offset = 0;
114 for (int img = 0; img < op_data->num_images; ++img) {
115 tflite::StringRef inputref =
116 tflite::GetString(input_buffer, /*string_index=*/img);
117
118 Status decode_status = decoder->DecodeImage(
119 inputref, {op_data->height, op_data->width, /*channels=*/3},
120 output_arr + output_array_offset, kImageSize);
121
122 output_array_offset += kImageSize;
123
124 if (decode_status.code != kTfLiteOk) {
125 TF_LITE_KERNEL_LOG(context, decode_status.error_message.c_str());
126 return kTfLiteError;
127 }
128 }
129 return kTfLiteOk;
130 }
131
Register_DECODE_JPEG()132 TfLiteRegistration* Register_DECODE_JPEG() {
133 static TfLiteRegistration r = {
134 decode_jpeg_kernel::Init, decode_jpeg_kernel::Free,
135 decode_jpeg_kernel::Prepare, decode_jpeg_kernel::Eval};
136 return &r;
137 }
138
139 } // namespace decode_jpeg_kernel
140 } // namespace acceleration
141 } // namespace tflite
142