• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/decode_jpeg_register.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/libjpeg_decoder.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/string_util.h"
27 
28 namespace tflite {
29 namespace acceleration {
30 namespace decode_jpeg_kernel {
31 
Init(TfLiteContext * context,const char * buffer,size_t length)32 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
33   if (!buffer) {
34     return nullptr;
35   }
36   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
37   const flexbuffers::Map m = flexbuffers::GetRoot(buffer_t, length).AsMap();
38   // TODO(b/172544567): Add error handling for incorrect/missing attributes.
39   OpData* op_data = new OpData();
40   op_data->height = m["height"].AsInt32();
41   op_data->width = m["width"].AsInt32();
42   op_data->num_images = m["num_images"].AsInt32();
43   return op_data;
44 }
45 
Free(TfLiteContext * context,void * buffer)46 void Free(TfLiteContext* context, void* buffer) {
47   delete reinterpret_cast<OpData*>(buffer);
48 }
49 
Prepare(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
51   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
52   TF_LITE_ENSURE(context, op_data);
53   TF_LITE_ENSURE(context, op_data->height > 0);
54   TF_LITE_ENSURE(context, op_data->width > 0);
55   TF_LITE_ENSURE(context, op_data->num_images > 0);
56 
57   TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
58   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
59 
60   const TfLiteTensor* input_buffer;
61   TF_LITE_ENSURE_OK(context,
62                     GetInputSafe(context, node, /*index=*/0, &input_buffer));
63 
64   TfLiteTensor* output_tensor;
65   TF_LITE_ENSURE_OK(context,
66                     GetOutputSafe(context, node, /*index=*/0, &output_tensor));
67 
68   TF_LITE_ENSURE_TYPES_EQ(context, input_buffer->type, kTfLiteString);
69   TF_LITE_ENSURE_TYPES_EQ(context, output_tensor->type, kTfLiteUInt8);
70 
71   TF_LITE_ENSURE_EQ(context, NumDimensions(input_buffer), 1);
72   TF_LITE_ENSURE_EQ(context, input_buffer->dims->data[0], op_data->num_images);
73 
74   // Resize output.
75   // Output shape is determined as {num_images, height, width, channels}.
76   TfLiteIntArray* new_dims = TfLiteIntArrayCreate(4);
77   new_dims->data[0] = op_data->num_images;
78   new_dims->data[1] = op_data->height;
79   new_dims->data[2] = op_data->width;
80   // TODO(b/172544567): Support grayscale images.
81   new_dims->data[3] = 3;  // Channels.
82   output_tensor->type = kTfLiteUInt8;
83   TF_LITE_ENSURE_OK(context,
84                     context->ResizeTensor(context, output_tensor, new_dims));
85   return kTfLiteOk;
86 }
87 
Eval(TfLiteContext * context,TfLiteNode * node)88 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
89   // Decodes a batch of JPEG images.
90 
91   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
92 
93   const TfLiteTensor* input_buffer;
94   TF_LITE_ENSURE_OK(context,
95                     GetInputSafe(context, node, /*index=*/0, &input_buffer));
96   TF_LITE_ENSURE(context, input_buffer);
97   TF_LITE_ENSURE(context, input_buffer->data.raw);
98   TfLiteTensor* output_tensor;
99   TF_LITE_ENSURE_OK(context,
100                     GetOutputSafe(context, node, /*index=*/0, &output_tensor));
101   // kTfliteUInt8 corresponds to unsigned char as shown in
102   // "tensorflow/lite/portable_type_to_tflitetype.h".
103   unsigned char* output_arr = GetTensorData<unsigned char>(output_tensor);
104   Status decoder_status;
105   std::unique_ptr<LibjpegDecoder> decoder =
106       LibjpegDecoder::Create(decoder_status);
107   if (decoder_status.code != kTfLiteOk) {
108     TF_LITE_KERNEL_LOG(context, decoder_status.error_message.c_str());
109     return kTfLiteError;
110   }
111 
112   const int kImageSize = op_data->width * op_data->height * 3;
113   int output_array_offset = 0;
114   for (int img = 0; img < op_data->num_images; ++img) {
115     tflite::StringRef inputref =
116         tflite::GetString(input_buffer, /*string_index=*/img);
117 
118     Status decode_status = decoder->DecodeImage(
119         inputref, {op_data->height, op_data->width, /*channels=*/3},
120         output_arr + output_array_offset, kImageSize);
121 
122     output_array_offset += kImageSize;
123 
124     if (decode_status.code != kTfLiteOk) {
125       TF_LITE_KERNEL_LOG(context, decode_status.error_message.c_str());
126       return kTfLiteError;
127     }
128   }
129   return kTfLiteOk;
130 }
131 
Register_DECODE_JPEG()132 TfLiteRegistration* Register_DECODE_JPEG() {
133   static TfLiteRegistration r = {
134       decode_jpeg_kernel::Init, decode_jpeg_kernel::Free,
135       decode_jpeg_kernel::Prepare, decode_jpeg_kernel::Eval};
136   return &r;
137 }
138 
139 }  // namespace decode_jpeg_kernel
140 }  // namespace acceleration
141 }  // namespace tflite
142