1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <jni.h>
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow_lite_support/cc/port/statusor.h"
23 #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
24 #include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
25 #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
26 #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
27 #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
28 #include "tensorflow_lite_support/cc/utils/jni_utils.h"
29 #include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
30
31 namespace {
32
33 using ::tflite::support::StatusOr;
34 using ::tflite::support::utils::CreateByteArray;
35 using ::tflite::support::utils::GetMappedFileBuffer;
36 using ::tflite::support::utils::kAssertionError;
37 using ::tflite::support::utils::kIllegalArgumentException;
38 using ::tflite::support::utils::kInvalidPointer;
39 using ::tflite::support::utils::ThrowException;
40 using ::tflite::task::vision::ConvertToFrameBufferOrientation;
41 using ::tflite::task::vision::FrameBuffer;
42 using ::tflite::task::vision::ImageSegmenter;
43 using ::tflite::task::vision::ImageSegmenterOptions;
44 using ::tflite::task::vision::Segmentation;
45 using ::tflite::task::vision::SegmentationResult;
46
// Java class names used for JNI lookups. The "NoSig" variants are binary class
// names passed to FindClass(); the others are type descriptors spliced into
// method signature strings.
constexpr char kArrayListClassNameNoSig[] = "java/util/ArrayList";
constexpr char kObjectClassName[] = "Ljava/lang/Object;";
constexpr char kColorClassNameNoSig[] = "android/graphics/Color";
constexpr char kColoredLabelClassName[] =
    "Lorg/tensorflow/lite/task/vision/segmenter/ColoredLabel;";
constexpr char kColoredLabelClassNameNoSig[] =
    "org/tensorflow/lite/task/vision/segmenter/ColoredLabel";
constexpr char kStringClassName[] = "Ljava/lang/String;";

// Accepted values of the `output_type` argument received from Java.
// NOTE(review): presumably these mirror the Java-side output type encoding;
// keep the two in sync.
constexpr int kOutputTypeCategoryMask = 0;
constexpr int kOutputTypeConfidenceMask = 1;
58
59 // Creates an ImageSegmenterOptions proto based on the Java class.
ConvertToProtoOptions(JNIEnv * env,jstring display_names_locale,jint output_type,jint num_threads)60 ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env,
61 jstring display_names_locale,
62 jint output_type,
63 jint num_threads) {
64 ImageSegmenterOptions proto_options;
65
66 const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
67 proto_options.set_display_names_locale(pchars);
68 env->ReleaseStringUTFChars(display_names_locale, pchars);
69
70 switch (output_type) {
71 case kOutputTypeCategoryMask:
72 proto_options.set_output_type(ImageSegmenterOptions::CATEGORY_MASK);
73 break;
74 case kOutputTypeConfidenceMask:
75 proto_options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
76 break;
77 default:
78 // Should never happen.
79 ThrowException(env, kIllegalArgumentException,
80 "Unsupported output type: %d", output_type);
81 }
82
83 proto_options.set_num_threads(num_threads);
84
85 return proto_options;
86 }
87
ConvertToSegmentationResults(JNIEnv * env,const SegmentationResult & results,jobject jmask_buffers,jintArray jmask_shape,jobject jcolored_labels)88 void ConvertToSegmentationResults(JNIEnv* env,
89 const SegmentationResult& results,
90 jobject jmask_buffers, jintArray jmask_shape,
91 jobject jcolored_labels) {
92 if (results.segmentation_size() != 1) {
93 // Should never happen.
94 ThrowException(
95 env, kAssertionError,
96 "ImageSegmenter only supports one segmentation result, getting %d",
97 results.segmentation_size());
98 }
99
100 const Segmentation& segmentation = results.segmentation(0);
101
102 // Get the shape from the C++ Segmentation results.
103 int shape_array[2] = {segmentation.height(), segmentation.width()};
104 env->SetIntArrayRegion(jmask_shape, 0, 2, shape_array);
105
106 // jclass, init, and add of ArrayList.
107 jclass array_list_class = env->FindClass(kArrayListClassNameNoSig);
108 jmethodID array_list_add_method =
109 env->GetMethodID(array_list_class, "add",
110 absl::StrCat("(", kObjectClassName, ")Z").c_str());
111
112 // Convert the masks into ByteBuffer list.
113 int num_pixels = segmentation.height() * segmentation.width();
114 if (segmentation.has_category_mask()) {
115 jbyteArray byte_array = CreateByteArray(
116 env,
117 reinterpret_cast<const jbyte*>(segmentation.category_mask().data()),
118 num_pixels * sizeof(uint8));
119 env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array);
120 env->DeleteLocalRef(byte_array);
121 } else {
122 for (const auto& confidence_mask :
123 segmentation.confidence_masks().confidence_mask()) {
124 jbyteArray byte_array = CreateByteArray(
125 env, reinterpret_cast<const jbyte*>(confidence_mask.value().data()),
126 num_pixels * sizeof(float));
127 env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array);
128 env->DeleteLocalRef(byte_array);
129 }
130 }
131
132 // Convert colored labels from the C++ object to the Java object.
133 jclass color_class = env->FindClass(kColorClassNameNoSig);
134 jmethodID color_rgb_method =
135 env->GetStaticMethodID(color_class, "rgb", "(III)I");
136 jclass colored_label_class = env->FindClass(kColoredLabelClassNameNoSig);
137 jmethodID colored_label_create_method = env->GetStaticMethodID(
138 colored_label_class, "create",
139 absl::StrCat("(", kStringClassName, kStringClassName, "I)",
140 kColoredLabelClassName)
141 .c_str());
142
143 for (const auto& colored_label : segmentation.colored_labels()) {
144 jstring label = env->NewStringUTF(colored_label.class_name().c_str());
145 jstring display_name =
146 env->NewStringUTF(colored_label.display_name().c_str());
147 jint rgb = env->CallStaticIntMethod(color_class, color_rgb_method,
148 colored_label.r(), colored_label.g(),
149 colored_label.b());
150 jobject jcolored_label = env->CallStaticObjectMethod(
151 colored_label_class, colored_label_create_method, label, display_name,
152 rgb);
153 env->CallBooleanMethod(jcolored_labels, array_list_add_method,
154 jcolored_label);
155
156 env->DeleteLocalRef(label);
157 env->DeleteLocalRef(display_name);
158 env->DeleteLocalRef(jcolored_label);
159 }
160 }
161
CreateImageClassifierFromOptions(JNIEnv * env,const ImageSegmenterOptions & options)162 jlong CreateImageClassifierFromOptions(JNIEnv* env,
163 const ImageSegmenterOptions& options) {
164 StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
165 ImageSegmenter::CreateFromOptions(options);
166 if (image_segmenter_or.ok()) {
167 return reinterpret_cast<jlong>(image_segmenter_or->release());
168 } else {
169 ThrowException(env, kAssertionError,
170 "Error occurred when initializing ImageSegmenter: %s",
171 image_segmenter_or.status().message().data());
172 return kInvalidPointer;
173 }
174 }
175
176 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(JNIEnv * env,jobject thiz,jlong native_handle)177 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
178 JNIEnv* env, jobject thiz, jlong native_handle) {
179 delete reinterpret_cast<ImageSegmenter*>(native_handle);
180 }
181
182 // Creates an ImageSegmenter instance from the model file descriptor.
183 // file_descriptor_length and file_descriptor_offset are optional. Non-possitive
184 // values will be ignored.
185 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(JNIEnv * env,jclass thiz,jint file_descriptor,jlong file_descriptor_length,jlong file_descriptor_offset,jstring display_names_locale,jint output_type,jint num_threads)186 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
187 JNIEnv* env, jclass thiz, jint file_descriptor,
188 jlong file_descriptor_length, jlong file_descriptor_offset,
189 jstring display_names_locale, jint output_type, jint num_threads) {
190 ImageSegmenterOptions proto_options = ConvertToProtoOptions(
191 env, display_names_locale, output_type, num_threads);
192 auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
193 ->mutable_file_descriptor_meta();
194 file_descriptor_meta->set_fd(file_descriptor);
195 if (file_descriptor_length > 0) {
196 file_descriptor_meta->set_length(file_descriptor_length);
197 }
198 if (file_descriptor_offset > 0) {
199 file_descriptor_meta->set_offset(file_descriptor_offset);
200 }
201 return CreateImageClassifierFromOptions(env, proto_options);
202 }
203
204 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(JNIEnv * env,jclass thiz,jobject model_buffer,jstring display_names_locale,jint output_type,jint num_threads)205 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
206 JNIEnv* env, jclass thiz, jobject model_buffer,
207 jstring display_names_locale, jint output_type, jint num_threads) {
208 ImageSegmenterOptions proto_options = ConvertToProtoOptions(
209 env, display_names_locale, output_type, num_threads);
210 proto_options.mutable_model_file_with_metadata()->set_file_content(
211 static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
212 static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
213 return CreateImageClassifierFromOptions(env, proto_options);
214 }
215
216 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(JNIEnv * env,jclass thiz,jlong native_handle,jobject jimage_byte_buffer,jint width,jint height,jobject jmask_buffers,jintArray jmask_shape,jobject jcolored_labels,jint jorientation)217 Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(
218 JNIEnv* env, jclass thiz, jlong native_handle, jobject jimage_byte_buffer,
219 jint width, jint height, jobject jmask_buffers, jintArray jmask_shape,
220 jobject jcolored_labels, jint jorientation) {
221 auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle);
222 absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer);
223 std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
224 reinterpret_cast<const uint8*>(image.data()),
225 FrameBuffer::Dimension{width, height},
226 ConvertToFrameBufferOrientation(env, jorientation));
227 auto results_or = segmenter->Segment(*frame_buffer);
228 if (results_or.ok()) {
229 ConvertToSegmentationResults(env, results_or.value(), jmask_buffers,
230 jmask_shape, jcolored_labels);
231 } else {
232 ThrowException(env, kAssertionError,
233 "Error occurred when segmenting the image: %s",
234 results_or.status().message().data());
235 }
236 }
237
238 } // namespace
239