1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string.h>
17 #include <memory>
18
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/java/src/main/native/utils_jni.h"
21 #include "tensorflow/java/src/main/native/exception_jni.h"
22 #include "tensorflow/java/src/main/native/session_jni.h"
23
24 namespace {
requireHandle(JNIEnv * env,jlong handle)25 TF_Session* requireHandle(JNIEnv* env, jlong handle) {
26 static_assert(sizeof(jlong) >= sizeof(TF_Session*),
27 "Cannot package C object pointers as a Java long");
28 if (handle == 0) {
29 throwException(env, kNullPointerException,
30 "close() has been called on the Session");
31 return nullptr;
32 }
33 return reinterpret_cast<TF_Session*>(handle);
34 }
35
36 template <class T>
resolveHandles(JNIEnv * env,const char * type,jlongArray src_array,T ** dst,jint n)37 void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
38 T** dst, jint n) {
39 if (env->ExceptionCheck()) return;
40 jint len = env->GetArrayLength(src_array);
41 if (len != n) {
42 throwException(env, kIllegalArgumentException, "expected %d, got %d %s", n,
43 len, type);
44 return;
45 }
46 jlong* src_start = env->GetLongArrayElements(src_array, nullptr);
47 jlong* src = src_start;
48 for (int i = 0; i < n; ++i, ++src, ++dst) {
49 if (*src == 0) {
50 throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
51 i, n);
52 break;
53 }
54 *dst = reinterpret_cast<T*>(*src);
55 }
56 env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
57 }
58
TF_MaybeDeleteBuffer(TF_Buffer * buf)59 void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
60 if (buf == nullptr) return;
61 TF_DeleteBuffer(buf);
62 }
63
64 typedef std::unique_ptr<TF_Buffer, decltype(&TF_MaybeDeleteBuffer)>
65 unique_tf_buffer;
66
MakeUniqueBuffer(TF_Buffer * buf)67 unique_tf_buffer MakeUniqueBuffer(TF_Buffer* buf) {
68 return unique_tf_buffer(buf, TF_MaybeDeleteBuffer);
69 }
70
71 } // namespace
72
Java_org_tensorflow_Session_allocate(JNIEnv * env,jclass clazz,jlong graph_handle)73 JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate(
74 JNIEnv* env, jclass clazz, jlong graph_handle) {
75 return Java_org_tensorflow_Session_allocate2(env, clazz, graph_handle,
76 nullptr, nullptr);
77 }
78
Java_org_tensorflow_Session_allocate2(JNIEnv * env,jclass clazz,jlong graph_handle,jstring target,jbyteArray config)79 JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(
80 JNIEnv* env, jclass clazz, jlong graph_handle, jstring target,
81 jbyteArray config) {
82 if (graph_handle == 0) {
83 throwException(env, kNullPointerException, "Graph has been close()d");
84 return 0;
85 }
86 TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
87 TF_Status* status = TF_NewStatus();
88 TF_SessionOptions* opts = TF_NewSessionOptions();
89 jbyte* cconfig = nullptr;
90 if (config != nullptr) {
91 cconfig = env->GetByteArrayElements(config, nullptr);
92 TF_SetConfig(opts, cconfig,
93 static_cast<size_t>(env->GetArrayLength(config)), status);
94 if (!throwExceptionIfNotOK(env, status)) {
95 env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
96 TF_DeleteSessionOptions(opts);
97 TF_DeleteStatus(status);
98 return 0;
99 }
100 }
101 const char* ctarget = nullptr;
102 if (target != nullptr) {
103 ctarget = env->GetStringUTFChars(target, nullptr);
104 }
105 TF_Session* session = TF_NewSession(graph, opts, status);
106 if (config != nullptr) {
107 env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
108 }
109 if (target != nullptr) {
110 env->ReleaseStringUTFChars(target, ctarget);
111 }
112 TF_DeleteSessionOptions(opts);
113 bool ok = throwExceptionIfNotOK(env, status);
114 TF_DeleteStatus(status);
115
116 return ok ? reinterpret_cast<jlong>(session) : 0;
117 }
118
Java_org_tensorflow_Session_delete(JNIEnv * env,jclass clazz,jlong handle)119 JNIEXPORT void JNICALL Java_org_tensorflow_Session_delete(JNIEnv* env,
120 jclass clazz,
121 jlong handle) {
122 TF_Session* session = requireHandle(env, handle);
123 if (session == nullptr) return;
124 TF_Status* status = TF_NewStatus();
125 TF_CloseSession(session, status);
126 // Result of close is ignored, delete anyway.
127 TF_DeleteSession(session, status);
128 throwExceptionIfNotOK(env, status);
129 TF_DeleteStatus(status);
130 }
131
Java_org_tensorflow_Session_run(JNIEnv * env,jclass clazz,jlong handle,jbyteArray jrun_options,jlongArray input_tensor_handles,jlongArray input_op_handles,jintArray input_op_indices,jlongArray output_op_handles,jintArray output_op_indices,jlongArray target_op_handles,jboolean want_run_metadata,jlongArray output_tensor_handles)132 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
133 JNIEnv* env, jclass clazz, jlong handle, jbyteArray jrun_options,
134 jlongArray input_tensor_handles, jlongArray input_op_handles,
135 jintArray input_op_indices, jlongArray output_op_handles,
136 jintArray output_op_indices, jlongArray target_op_handles,
137 jboolean want_run_metadata, jlongArray output_tensor_handles) {
138 TF_Session* session = requireHandle(env, handle);
139 if (session == nullptr) return nullptr;
140
141 const jint ninputs = env->GetArrayLength(input_tensor_handles);
142 const jint noutputs = env->GetArrayLength(output_tensor_handles);
143 const jint ntargets = env->GetArrayLength(target_op_handles);
144
145 std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
146 std::unique_ptr<TF_Tensor* []> input_values(new TF_Tensor*[ninputs]);
147 std::unique_ptr<TF_Output[]> outputs(new TF_Output[noutputs]);
148 std::unique_ptr<TF_Tensor* []> output_values(new TF_Tensor*[noutputs]);
149 std::unique_ptr<TF_Operation* []> targets(new TF_Operation*[ntargets]);
150 unique_tf_buffer run_metadata(
151 MakeUniqueBuffer(want_run_metadata ? TF_NewBuffer() : nullptr));
152
153 resolveHandles(env, "input Tensors", input_tensor_handles, input_values.get(),
154 ninputs);
155 resolveOutputs(env, "input", input_op_handles, input_op_indices, inputs.get(),
156 ninputs);
157 resolveOutputs(env, "output", output_op_handles, output_op_indices,
158 outputs.get(), noutputs);
159 resolveHandles(env, "target Operations", target_op_handles, targets.get(),
160 ntargets);
161 if (env->ExceptionCheck()) return nullptr;
162
163 TF_Status* status = TF_NewStatus();
164
165 unique_tf_buffer run_options(MakeUniqueBuffer(nullptr));
166 jbyte* jrun_options_data = nullptr;
167 if (jrun_options != nullptr) {
168 size_t sz = env->GetArrayLength(jrun_options);
169 if (sz > 0) {
170 jrun_options_data = env->GetByteArrayElements(jrun_options, nullptr);
171 run_options.reset(
172 TF_NewBufferFromString(static_cast<void*>(jrun_options_data), sz));
173 }
174 }
175
176 TF_SessionRun(session, run_options.get(), inputs.get(), input_values.get(),
177 static_cast<int>(ninputs), outputs.get(), output_values.get(),
178 static_cast<int>(noutputs), targets.get(),
179 static_cast<int>(ntargets), run_metadata.get(), status);
180
181 if (jrun_options_data != nullptr) {
182 env->ReleaseByteArrayElements(jrun_options, jrun_options_data, JNI_ABORT);
183 }
184
185 if (!throwExceptionIfNotOK(env, status)) {
186 TF_DeleteStatus(status);
187 return nullptr;
188 }
189 jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);
190 for (int i = 0; i < noutputs; ++i) {
191 t[i] = reinterpret_cast<jlong>(output_values[i]);
192 }
193 env->ReleaseLongArrayElements(output_tensor_handles, t, 0);
194
195 jbyteArray ret = nullptr;
196 if (run_metadata != nullptr) {
197 ret = env->NewByteArray(run_metadata->length);
198 env->SetByteArrayRegion(ret, 0, run_metadata->length,
199 reinterpret_cast<const jbyte*>(run_metadata->data));
200 }
201 TF_DeleteStatus(status);
202 return ret;
203 }
204