• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <jni.h>
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow_lite_support/cc/port/statusor.h"
23 #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
24 #include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
25 #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
26 #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
27 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
28 #include "tensorflow_lite_support/cc/utils/jni_utils.h"
29 #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
30 
31 namespace {
32 
33 using ::tflite::support::StatusOr;
34 using ::tflite::support::utils::CreateByteArray;
35 using ::tflite::support::utils::GetMappedFileBuffer;
36 using ::tflite::support::utils::kAssertionError;
37 using ::tflite::support::utils::kIllegalArgumentException;
38 using ::tflite::support::utils::kInvalidPointer;
39 using ::tflite::support::utils::ThrowException;
40 using ::tflite::task::vision::ConvertToFrameBufferOrientation;
41 using ::tflite::task::vision::FrameBuffer;
42 using ::tflite::task::vision::ImageSegmenter;
43 using ::tflite::task::vision::ImageSegmenterOptions;
44 using ::tflite::task::vision::Segmentation;
45 using ::tflite::task::vision::SegmentationResult;
46 
// Java class names used when building results for the Java side.
// The "...NoSig" variants are slash-separated names as passed to
// JNIEnv::FindClass; the others are "L...;" type descriptors that are spliced
// into JNI method signatures.
constexpr char kArrayListClassNameNoSig[] = "java/util/ArrayList";
constexpr char kObjectClassName[] = "Ljava/lang/Object;";
constexpr char kStringClassName[] = "Ljava/lang/String;";
constexpr char kColorClassName[] = "Landroid/graphics/Color;";
constexpr char kColorClassNameNoSig[] = "android/graphics/Color";
constexpr char kColoredLabelClassName[] =
    "Lorg/tensorflow/lite/task/vision/segmenter/ColoredLabel;";
constexpr char kColoredLabelClassNameNoSig[] =
    "org/tensorflow/lite/task/vision/segmenter/ColoredLabel";

// Integer codes for the requested output type, as received from Java and
// mapped onto ImageSegmenterOptions output types by ConvertToProtoOptions.
constexpr int kOutputTypeCategoryMask{0};
constexpr int kOutputTypeConfidenceMask{1};
58 
59 // Creates an ImageSegmenterOptions proto based on the Java class.
ConvertToProtoOptions(JNIEnv * env,jstring display_names_locale,jint output_type,jint num_threads)60 ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env,
61                                             jstring display_names_locale,
62                                             jint output_type,
63                                             jint num_threads) {
64   ImageSegmenterOptions proto_options;
65 
66   const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
67   proto_options.set_display_names_locale(pchars);
68   env->ReleaseStringUTFChars(display_names_locale, pchars);
69 
70   switch (output_type) {
71     case kOutputTypeCategoryMask:
72       proto_options.set_output_type(ImageSegmenterOptions::CATEGORY_MASK);
73       break;
74     case kOutputTypeConfidenceMask:
75       proto_options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
76       break;
77     default:
78       // Should never happen.
79       ThrowException(env, kIllegalArgumentException,
80                      "Unsupported output type: %d", output_type);
81   }
82 
83   proto_options.set_num_threads(num_threads);
84 
85   return proto_options;
86 }
87 
ConvertToSegmentationResults(JNIEnv * env,const SegmentationResult & results,jobject jmask_buffers,jintArray jmask_shape,jobject jcolored_labels)88 void ConvertToSegmentationResults(JNIEnv* env,
89                                   const SegmentationResult& results,
90                                   jobject jmask_buffers, jintArray jmask_shape,
91                                   jobject jcolored_labels) {
92   if (results.segmentation_size() != 1) {
93     // Should never happen.
94     ThrowException(
95         env, kAssertionError,
96         "ImageSegmenter only supports one segmentation result, getting %d",
97         results.segmentation_size());
98   }
99 
100   const Segmentation& segmentation = results.segmentation(0);
101 
102   // Get the shape from the C++ Segmentation results.
103   int shape_array[2] = {segmentation.height(), segmentation.width()};
104   env->SetIntArrayRegion(jmask_shape, 0, 2, shape_array);
105 
106   // jclass, init, and add of ArrayList.
107   jclass array_list_class = env->FindClass(kArrayListClassNameNoSig);
108   jmethodID array_list_add_method =
109       env->GetMethodID(array_list_class, "add",
110                        absl::StrCat("(", kObjectClassName, ")Z").c_str());
111 
112   // Convert the masks into ByteBuffer list.
113   int num_pixels = segmentation.height() * segmentation.width();
114   if (segmentation.has_category_mask()) {
115     jbyteArray byte_array = CreateByteArray(
116         env,
117         reinterpret_cast<const jbyte*>(segmentation.category_mask().data()),
118         num_pixels * sizeof(uint8));
119     env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array);
120     env->DeleteLocalRef(byte_array);
121   } else {
122     for (const auto& confidence_mask :
123          segmentation.confidence_masks().confidence_mask()) {
124       jbyteArray byte_array = CreateByteArray(
125           env, reinterpret_cast<const jbyte*>(confidence_mask.value().data()),
126           num_pixels * sizeof(float));
127       env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array);
128       env->DeleteLocalRef(byte_array);
129     }
130   }
131 
132   // Convert colored labels from the C++ object to the Java object.
133   jclass color_class = env->FindClass(kColorClassNameNoSig);
134   jmethodID color_rgb_method =
135       env->GetStaticMethodID(color_class, "rgb", "(III)I");
136   jclass colored_label_class = env->FindClass(kColoredLabelClassNameNoSig);
137   jmethodID colored_label_create_method = env->GetStaticMethodID(
138       colored_label_class, "create",
139       absl::StrCat("(", kStringClassName, kStringClassName, "I)",
140                    kColoredLabelClassName)
141           .c_str());
142 
143   for (const auto& colored_label : segmentation.colored_labels()) {
144     jstring label = env->NewStringUTF(colored_label.class_name().c_str());
145     jstring display_name =
146         env->NewStringUTF(colored_label.display_name().c_str());
147     jint rgb = env->CallStaticIntMethod(color_class, color_rgb_method,
148                                         colored_label.r(), colored_label.g(),
149                                         colored_label.b());
150     jobject jcolored_label = env->CallStaticObjectMethod(
151         colored_label_class, colored_label_create_method, label, display_name,
152         rgb);
153     env->CallBooleanMethod(jcolored_labels, array_list_add_method,
154                            jcolored_label);
155 
156     env->DeleteLocalRef(label);
157     env->DeleteLocalRef(display_name);
158     env->DeleteLocalRef(jcolored_label);
159   }
160 }
161 
CreateImageClassifierFromOptions(JNIEnv * env,const ImageSegmenterOptions & options)162 jlong CreateImageClassifierFromOptions(JNIEnv* env,
163                                        const ImageSegmenterOptions& options) {
164   StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
165       ImageSegmenter::CreateFromOptions(options);
166   if (image_segmenter_or.ok()) {
167     return reinterpret_cast<jlong>(image_segmenter_or->release());
168   } else {
169     ThrowException(env, kAssertionError,
170                    "Error occurred when initializing ImageSegmenter: %s",
171                    image_segmenter_or.status().message().data());
172     return kInvalidPointer;
173   }
174 }
175 
176 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(JNIEnv * env,jobject thiz,jlong native_handle)177 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
178     JNIEnv* env, jobject thiz, jlong native_handle) {
179   delete reinterpret_cast<ImageSegmenter*>(native_handle);
180 }
181 
182 // Creates an ImageSegmenter instance from the model file descriptor.
183 // file_descriptor_length and file_descriptor_offset are optional. Non-possitive
184 // values will be ignored.
185 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(JNIEnv * env,jclass thiz,jint file_descriptor,jlong file_descriptor_length,jlong file_descriptor_offset,jstring display_names_locale,jint output_type,jint num_threads)186 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
187     JNIEnv* env, jclass thiz, jint file_descriptor,
188     jlong file_descriptor_length, jlong file_descriptor_offset,
189     jstring display_names_locale, jint output_type, jint num_threads) {
190   ImageSegmenterOptions proto_options = ConvertToProtoOptions(
191       env, display_names_locale, output_type, num_threads);
192   auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
193                                   ->mutable_file_descriptor_meta();
194   file_descriptor_meta->set_fd(file_descriptor);
195   if (file_descriptor_length > 0) {
196     file_descriptor_meta->set_length(file_descriptor_length);
197   }
198   if (file_descriptor_offset > 0) {
199     file_descriptor_meta->set_offset(file_descriptor_offset);
200   }
201   return CreateImageClassifierFromOptions(env, proto_options);
202 }
203 
204 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(JNIEnv * env,jclass thiz,jobject model_buffer,jstring display_names_locale,jint output_type,jint num_threads)205 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
206     JNIEnv* env, jclass thiz, jobject model_buffer,
207     jstring display_names_locale, jint output_type, jint num_threads) {
208   ImageSegmenterOptions proto_options = ConvertToProtoOptions(
209       env, display_names_locale, output_type, num_threads);
210   proto_options.mutable_model_file_with_metadata()->set_file_content(
211       static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
212       static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
213   return CreateImageClassifierFromOptions(env, proto_options);
214 }
215 
216 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(JNIEnv * env,jclass thiz,jlong native_handle,jobject jimage_byte_buffer,jint width,jint height,jobject jmask_buffers,jintArray jmask_shape,jobject jcolored_labels,jint jorientation)217 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(
218     JNIEnv* env, jclass thiz, jlong native_handle, jobject jimage_byte_buffer,
219     jint width, jint height, jobject jmask_buffers, jintArray jmask_shape,
220     jobject jcolored_labels, jint jorientation) {
221   auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle);
222   absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer);
223   std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
224       reinterpret_cast<const uint8*>(image.data()),
225       FrameBuffer::Dimension{width, height},
226       ConvertToFrameBufferOrientation(env, jorientation));
227   auto results_or = segmenter->Segment(*frame_buffer);
228   if (results_or.ok()) {
229     ConvertToSegmentationResults(env, results_or.value(), jmask_buffers,
230                                  jmask_shape, jcolored_labels);
231   } else {
232     ThrowException(env, kAssertionError,
233                    "Error occurred when segmenting the image: %s",
234                    results_or.status().message().data());
235   }
236 }
237 
238 }  // namespace
239