1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/eager_operation_builder_jni.h"
17
18 #include <cstring>
19 #include <memory>
20 #include <set>
21
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/java/src/main/native/exception_jni.h"
24
// This value should be greater than or equal to the maximum number of outputs
// of any op.
26 #define MAX_OUTPUTS_PER_OP 8
27
28 namespace {
29
requireOp(JNIEnv * env,jlong handle)30 TFE_Op* requireOp(JNIEnv* env, jlong handle) {
31 if (handle == 0) {
32 throwException(env, kIllegalStateException,
33 "Operation has already been built");
34 return nullptr;
35 }
36 return reinterpret_cast<TFE_Op*>(handle);
37 }
38
requireContext(JNIEnv * env,jlong handle)39 TFE_Context* requireContext(JNIEnv* env, jlong handle) {
40 if (handle == 0) {
41 throwException(env, kIllegalStateException, "Context has been deleted");
42 return nullptr;
43 }
44 return reinterpret_cast<TFE_Context*>(handle);
45 }
46
requireTensor(JNIEnv * env,jlong handle)47 TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
48 if (handle == 0) {
49 throwException(env, kIllegalStateException,
50 "close() has been called on the Tensor");
51 return nullptr;
52 }
53 return reinterpret_cast<TF_Tensor*>(handle);
54 }
55
requireTensorHandle(JNIEnv * env,jlong handle)56 TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
57 if (handle == 0) {
58 throwException(env, kIllegalStateException,
59 "Tensor handle has been deleted");
60 return nullptr;
61 }
62 return reinterpret_cast<TFE_TensorHandle*>(handle);
63 }
64
65 } // namespace
66
Java_org_tensorflow_EagerOperationBuilder_allocate(JNIEnv * env,jclass clazz,jlong context_handle,jstring name)67 JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
68 JNIEnv* env, jclass clazz, jlong context_handle, jstring name) {
69 TFE_Context* context = requireContext(env, context_handle);
70 if (context == nullptr) return 0;
71 const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
72 TF_Status* status = TF_NewStatus();
73 TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
74 env->ReleaseStringUTFChars(name, op_or_function_name);
75 if (!throwExceptionIfNotOK(env, status)) {
76 TF_DeleteStatus(status);
77 return 0;
78 }
79 TF_DeleteStatus(status);
80 static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
81 "Cannot represent a C TFE_Op as a Java long");
82 return reinterpret_cast<jlong>(op);
83 }
84
Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv * env,jclass clazz,jlong op_handle)85 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete(
86 JNIEnv* env, jclass clazz, jlong op_handle) {
87 if (op_handle == 0) return;
88 TFE_DeleteOp(reinterpret_cast<TFE_Op*>(op_handle));
89 }
90
Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv * env,jclass clazz,jlong op_handle)91 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
92 JNIEnv* env, jclass clazz, jlong op_handle) {
93 TFE_Op* op = requireOp(env, op_handle);
94 if (op == nullptr) return 0;
95 int num_retvals = MAX_OUTPUTS_PER_OP;
96 std::unique_ptr<TFE_TensorHandle*[]> retvals(
97 new TFE_TensorHandle*[num_retvals]);
98 TF_Status* status = TF_NewStatus();
99 TFE_Execute(op, retvals.get(), &num_retvals, status);
100 if (!throwExceptionIfNotOK(env, status)) {
101 TF_DeleteStatus(status);
102 return nullptr;
103 }
104 TF_DeleteStatus(status);
105 jlongArray rethandles = env->NewLongArray(num_retvals);
106 if (num_retvals > 0) {
107 jlong* retval = env->GetLongArrayElements(rethandles, nullptr);
108 for (int i = 0; i < num_retvals; ++i) {
109 retval[i] = reinterpret_cast<jlong>(retvals[i]);
110 }
111 env->ReleaseLongArrayElements(rethandles, retval, 0);
112 }
113 return rethandles;
114 }
115
Java_org_tensorflow_EagerOperationBuilder_setDevice(JNIEnv * env,jclass clazz,jlong op_handle,jstring device_name)116 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
117 JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) {
118 TFE_Op* op = requireOp(env, op_handle);
119 if (op == nullptr) return;
120 const char* cname = env->GetStringUTFChars(device_name, nullptr);
121 TF_Status* status = TF_NewStatus();
122 TFE_OpSetDevice(op, cname, status);
123 throwExceptionIfNotOK(env, status);
124 TF_DeleteStatus(status);
125 env->ReleaseStringUTFChars(device_name, cname);
126 }
127
Java_org_tensorflow_EagerOperationBuilder_addInput(JNIEnv * env,jclass clazz,jlong op_handle,jlong input_handle)128 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
129 JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) {
130 TFE_Op* op = requireOp(env, op_handle);
131 if (op == nullptr) return;
132 TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle);
133 if (tensor_handle == nullptr) return;
134 TF_Status* status = TF_NewStatus();
135 TFE_OpAddInput(op, tensor_handle, status);
136 throwExceptionIfNotOK(env, status);
137 TF_DeleteStatus(status);
138 }
139
Java_org_tensorflow_EagerOperationBuilder_addInputList(JNIEnv * env,jclass clazz,jlong op_handle,jlongArray input_handles)140 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
141 JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) {
142 TFE_Op* op = requireOp(env, op_handle);
143 if (op == nullptr) return;
144 jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr);
145 size_t num_inputs = static_cast<size_t>(env->GetArrayLength(input_handles));
146 std::unique_ptr<TFE_TensorHandle*[]> tensor_handles(
147 new TFE_TensorHandle*[num_inputs]);
148 for (int i = 0; i < num_inputs; ++i) {
149 tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]);
150 if (tensor_handles[i] == nullptr) {
151 env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
152 return;
153 }
154 }
155 env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
156 TF_Status* status = TF_NewStatus();
157 TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status);
158 throwExceptionIfNotOK(env, status);
159 TF_DeleteStatus(status);
160 }
161
Java_org_tensorflow_EagerOperationBuilder_setAttrString(JNIEnv * env,jclass clazz,jlong op_handle,jstring attr_name,jbyteArray value)162 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
163 JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
164 jbyteArray value) {
165 static_assert(sizeof(jbyte) == 1,
166 "Require Java byte to be represented as a single byte");
167 TFE_Op* op = requireOp(env, op_handle);
168 if (op == nullptr) return;
169 const char* cname = env->GetStringUTFChars(attr_name, nullptr);
170 jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
171 TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value));
172 env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
173 env->ReleaseStringUTFChars(attr_name, cname);
174 }
175
176 JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv * env,jclass object,jlong op_handle,jstring attr_name,jobjectArray values)177 Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(
178 JNIEnv* env, jclass object, jlong op_handle, jstring attr_name,
179 jobjectArray values) {
180 TFE_Op* op = requireOp(env, op_handle);
181 if (op == nullptr) return;
182 const char* cname = env->GetStringUTFChars(attr_name, nullptr);
183 int num_values = env->GetArrayLength(values);
184 static_assert(sizeof(jbyte) == 1,
185 "Require Java byte to be represented as a single byte");
186 std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
187 std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]);
188 std::unique_ptr<void*[]> cvalues(new void*[num_values]);
189 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
190
191 for (int i = 0; i < num_values; ++i) {
192 jbyteArray v =
193 static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
194 jarrays[i] = v;
195 jvalues[i] = env->GetByteArrayElements(v, nullptr);
196 cvalues[i] = jvalues[i];
197 lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
198 }
199 TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values);
200 for (int i = 0; i < num_values; ++i) {
201 env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
202 }
203 env->ReleaseStringUTFChars(attr_name, cname);
204 }
205
206 #define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
207 JNIEXPORT void JNICALL \
208 Java_org_tensorflow_EagerOperationBuilder_setAttr##name( \
209 JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
210 jtype value) { \
211 static_assert( \
212 sizeof(ctype) >= sizeof(jtype), \
213 "Information loss when converting between Java and C types"); \
214 TFE_Op* op = requireOp(env, op_handle); \
215 if (op == nullptr) return; \
216 const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
217 TFE_OpSetAttr##name(op, cname, static_cast<ctype>(value)); \
218 env->ReleaseStringUTFChars(attr_name, cname); \
219 }
220
221 #define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
222 JNIEXPORT void JNICALL \
223 Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List( \
224 JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
225 jtype##Array value) { \
226 TFE_Op* op = requireOp(env, op_handle); \
227 if (op == nullptr) return; \
228 const char* cname = env->GetStringUTFChars(attr_name, nullptr); \
229 /* Make a copy of the array to paper over any differences */ \
230 /* in byte representations of the jtype and ctype */ \
231 /* For example, jint vs TF_DataType. */ \
232 /* If this copy turns out to be a problem in practice */ \
233 /* can avoid it for many types. */ \
234 const int n = env->GetArrayLength(value); \
235 std::unique_ptr<ctype[]> cvalue(new ctype[n]); \
236 jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
237 for (int i = 0; i < n; ++i) { \
238 cvalue[i] = static_cast<ctype>(elems[i]); \
239 } \
240 TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n); \
241 env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
242 env->ReleaseStringUTFChars(attr_name, cname); \
243 }
244
245 #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
246 DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
247 DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
248
249 DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
250 DEFINE_SET_ATTR(Float, Float, jfloat, float);
251 DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
252 DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
253 #undef DEFINE_SET_ATTR
254 #undef DEFINE_SET_ATTR_LIST
255 #undef DEFINE_SET_ATTR_SCALAR
256
Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(JNIEnv * env,jclass clazz,jlong handle,jstring attr_name,jlong tensor_handle)257 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
258 JNIEnv* env, jclass clazz, jlong handle, jstring attr_name,
259 jlong tensor_handle) {
260 TFE_Op* op = requireOp(env, handle);
261 if (op == nullptr) return;
262 TF_Tensor* t = requireTensor(env, tensor_handle);
263 if (t == nullptr) return;
264 const char* cname = env->GetStringUTFChars(attr_name, nullptr);
265 TF_Status* status = TF_NewStatus();
266 TFE_OpSetAttrTensor(op, cname, t, status);
267 throwExceptionIfNotOK(env, status);
268 TF_DeleteStatus(status);
269 env->ReleaseStringUTFChars(attr_name, cname);
270 }
271
Java_org_tensorflow_EagerOperationBuilder_setAttrShape(JNIEnv * env,jclass clazz,jlong op_handle,jstring attr_name,jlongArray shape,jint num_dims)272 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
273 JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
274 jlongArray shape, jint num_dims) {
275 TFE_Op* op = requireOp(env, op_handle);
276 if (op == nullptr) return;
277 std::unique_ptr<int64_t[]> cvalue;
278 // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
279 // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
280 if (num_dims > 0) {
281 cvalue.reset(new int64_t[num_dims]);
282 jlong* elems = env->GetLongArrayElements(shape, nullptr);
283 for (int i = 0; i < num_dims; ++i) {
284 cvalue[i] = static_cast<int64_t>(elems[i]);
285 }
286 env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
287 }
288 const char* cname = env->GetStringUTFChars(attr_name, nullptr);
289 TF_Status* status = TF_NewStatus();
290 TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast<int>(num_dims),
291 status);
292 throwExceptionIfNotOK(env, status);
293 TF_DeleteStatus(status);
294 env->ReleaseStringUTFChars(attr_name, cname);
295 }
296
297 JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv * env,jclass clazz,jlong op_handle,jstring attr_name,jlongArray shapes,jintArray num_dims)298 Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(
299 JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
300 jlongArray shapes, jintArray num_dims) {
301 TFE_Op* op = requireOp(env, op_handle);
302 if (op == nullptr) return;
303 std::unique_ptr<int64_t[]> cshapes;
304 std::unique_ptr<const int64_t*[]> cdims;
305 std::unique_ptr<int[]> cnum_dims;
306 const int num_dims_length = env->GetArrayLength(num_dims);
307 if (num_dims_length > 0) {
308 const int shapes_length = env->GetArrayLength(shapes);
309 cshapes.reset(new int64_t[shapes_length]);
310 cdims.reset(new const int64_t*[num_dims_length]);
311 cnum_dims.reset(new int[num_dims_length]);
312 jlong* shapes_elems =
313 static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
314 std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
315 env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
316 int64_t* cshapes_ptr = cshapes.get();
317 jint* num_dims_elems =
318 static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
319 for (int i = 0; i < num_dims_length; ++i) {
320 cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
321 cdims[i] = cshapes_ptr;
322 if (cnum_dims[i] > 0) {
323 cshapes_ptr += cnum_dims[i];
324 }
325 }
326 env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
327 }
328 const char* cname = env->GetStringUTFChars(attr_name, nullptr);
329 TF_Status* status = TF_NewStatus();
330 TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(),
331 num_dims_length, status);
332 throwExceptionIfNotOK(env, status);
333 TF_DeleteStatus(status);
334 env->ReleaseStringUTFChars(attr_name, cname);
335 }
336