• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/java/src/main/native/eager_operation_builder_jni.h"

#include <cstring>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
24 
// Capacity of the output buffer that execute() hands to TFE_Execute. This
// value must be >= the maximum number of outputs of any op; otherwise the
// buffer is too small to receive every result. Kept in its former macro-style
// spelling so existing uses continue to compile unchanged.
constexpr int MAX_OUTPUTS_PER_OP = 8;
27 
28 namespace {
29 
requireOp(JNIEnv * env,jlong handle)30 TFE_Op* requireOp(JNIEnv* env, jlong handle) {
31   if (handle == 0) {
32     throwException(env, kIllegalStateException,
33                    "Operation has already been built");
34     return nullptr;
35   }
36   return reinterpret_cast<TFE_Op*>(handle);
37 }
38 
requireContext(JNIEnv * env,jlong handle)39 TFE_Context* requireContext(JNIEnv* env, jlong handle) {
40   if (handle == 0) {
41     throwException(env, kIllegalStateException, "Context has been deleted");
42     return nullptr;
43   }
44   return reinterpret_cast<TFE_Context*>(handle);
45 }
46 
requireTensor(JNIEnv * env,jlong handle)47 TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
48   if (handle == 0) {
49     throwException(env, kIllegalStateException,
50                    "close() has been called on the Tensor");
51     return nullptr;
52   }
53   return reinterpret_cast<TF_Tensor*>(handle);
54 }
55 
requireTensorHandle(JNIEnv * env,jlong handle)56 TFE_TensorHandle* requireTensorHandle(JNIEnv* env, jlong handle) {
57   if (handle == 0) {
58     throwException(env, kIllegalStateException,
59                    "Tensor handle has been deleted");
60     return nullptr;
61   }
62   return reinterpret_cast<TFE_TensorHandle*>(handle);
63 }
64 
65 }  // namespace
66 
Java_org_tensorflow_EagerOperationBuilder_allocate(JNIEnv * env,jclass clazz,jlong context_handle,jstring name)67 JNIEXPORT jlong JNICALL Java_org_tensorflow_EagerOperationBuilder_allocate(
68     JNIEnv* env, jclass clazz, jlong context_handle, jstring name) {
69   TFE_Context* context = requireContext(env, context_handle);
70   if (context == nullptr) return 0;
71   const char* op_or_function_name = env->GetStringUTFChars(name, nullptr);
72   TF_Status* status = TF_NewStatus();
73   TFE_Op* op = TFE_NewOp(context, op_or_function_name, status);
74   env->ReleaseStringUTFChars(name, op_or_function_name);
75   if (!throwExceptionIfNotOK(env, status)) {
76     TF_DeleteStatus(status);
77     return 0;
78   }
79   TF_DeleteStatus(status);
80   static_assert(sizeof(jlong) >= sizeof(TFE_Op*),
81                 "Cannot represent a C TFE_Op as a Java long");
82   return reinterpret_cast<jlong>(op);
83 }
84 
Java_org_tensorflow_EagerOperationBuilder_delete(JNIEnv * env,jclass clazz,jlong op_handle)85 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_delete(
86     JNIEnv* env, jclass clazz, jlong op_handle) {
87   if (op_handle == 0) return;
88   TFE_DeleteOp(reinterpret_cast<TFE_Op*>(op_handle));
89 }
90 
Java_org_tensorflow_EagerOperationBuilder_execute(JNIEnv * env,jclass clazz,jlong op_handle)91 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_EagerOperationBuilder_execute(
92     JNIEnv* env, jclass clazz, jlong op_handle) {
93   TFE_Op* op = requireOp(env, op_handle);
94   if (op == nullptr) return 0;
95   int num_retvals = MAX_OUTPUTS_PER_OP;
96   std::unique_ptr<TFE_TensorHandle*[]> retvals(
97       new TFE_TensorHandle*[num_retvals]);
98   TF_Status* status = TF_NewStatus();
99   TFE_Execute(op, retvals.get(), &num_retvals, status);
100   if (!throwExceptionIfNotOK(env, status)) {
101     TF_DeleteStatus(status);
102     return nullptr;
103   }
104   TF_DeleteStatus(status);
105   jlongArray rethandles = env->NewLongArray(num_retvals);
106   if (num_retvals > 0) {
107     jlong* retval = env->GetLongArrayElements(rethandles, nullptr);
108     for (int i = 0; i < num_retvals; ++i) {
109       retval[i] = reinterpret_cast<jlong>(retvals[i]);
110     }
111     env->ReleaseLongArrayElements(rethandles, retval, 0);
112   }
113   return rethandles;
114 }
115 
Java_org_tensorflow_EagerOperationBuilder_setDevice(JNIEnv * env,jclass clazz,jlong op_handle,jstring device_name)116 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setDevice(
117     JNIEnv* env, jclass clazz, jlong op_handle, jstring device_name) {
118   TFE_Op* op = requireOp(env, op_handle);
119   if (op == nullptr) return;
120   const char* cname = env->GetStringUTFChars(device_name, nullptr);
121   TF_Status* status = TF_NewStatus();
122   TFE_OpSetDevice(op, cname, status);
123   throwExceptionIfNotOK(env, status);
124   TF_DeleteStatus(status);
125   env->ReleaseStringUTFChars(device_name, cname);
126 }
127 
Java_org_tensorflow_EagerOperationBuilder_addInput(JNIEnv * env,jclass clazz,jlong op_handle,jlong input_handle)128 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInput(
129     JNIEnv* env, jclass clazz, jlong op_handle, jlong input_handle) {
130   TFE_Op* op = requireOp(env, op_handle);
131   if (op == nullptr) return;
132   TFE_TensorHandle* tensor_handle = requireTensorHandle(env, input_handle);
133   if (tensor_handle == nullptr) return;
134   TF_Status* status = TF_NewStatus();
135   TFE_OpAddInput(op, tensor_handle, status);
136   throwExceptionIfNotOK(env, status);
137   TF_DeleteStatus(status);
138 }
139 
Java_org_tensorflow_EagerOperationBuilder_addInputList(JNIEnv * env,jclass clazz,jlong op_handle,jlongArray input_handles)140 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_addInputList(
141     JNIEnv* env, jclass clazz, jlong op_handle, jlongArray input_handles) {
142   TFE_Op* op = requireOp(env, op_handle);
143   if (op == nullptr) return;
144   jlong* cinput_handles = env->GetLongArrayElements(input_handles, nullptr);
145   size_t num_inputs = static_cast<size_t>(env->GetArrayLength(input_handles));
146   std::unique_ptr<TFE_TensorHandle*[]> tensor_handles(
147       new TFE_TensorHandle*[num_inputs]);
148   for (int i = 0; i < num_inputs; ++i) {
149     tensor_handles[i] = requireTensorHandle(env, cinput_handles[i]);
150     if (tensor_handles[i] == nullptr) {
151       env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
152       return;
153     }
154   }
155   env->ReleaseLongArrayElements(input_handles, cinput_handles, JNI_ABORT);
156   TF_Status* status = TF_NewStatus();
157   TFE_OpAddInputList(op, tensor_handles.get(), num_inputs, status);
158   throwExceptionIfNotOK(env, status);
159   TF_DeleteStatus(status);
160 }
161 
Java_org_tensorflow_EagerOperationBuilder_setAttrString(JNIEnv * env,jclass clazz,jlong op_handle,jstring attr_name,jbyteArray value)162 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrString(
163     JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
164     jbyteArray value) {
165   static_assert(sizeof(jbyte) == 1,
166                 "Require Java byte to be represented as a single byte");
167   TFE_Op* op = requireOp(env, op_handle);
168   if (op == nullptr) return;
169   const char* cname = env->GetStringUTFChars(attr_name, nullptr);
170   jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
171   TFE_OpSetAttrString(op, cname, cvalue, env->GetArrayLength(value));
172   env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
173   env->ReleaseStringUTFChars(attr_name, cname);
174 }
175 
176 JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(JNIEnv * env,jclass object,jlong op_handle,jstring attr_name,jobjectArray values)177 Java_org_tensorflow_EagerOperationBuilder_setAttrStringList(
178     JNIEnv* env, jclass object, jlong op_handle, jstring attr_name,
179     jobjectArray values) {
180   TFE_Op* op = requireOp(env, op_handle);
181   if (op == nullptr) return;
182   const char* cname = env->GetStringUTFChars(attr_name, nullptr);
183   int num_values = env->GetArrayLength(values);
184   static_assert(sizeof(jbyte) == 1,
185                 "Require Java byte to be represented as a single byte");
186   std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
187   std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]);
188   std::unique_ptr<void*[]> cvalues(new void*[num_values]);
189   std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
190 
191   for (int i = 0; i < num_values; ++i) {
192     jbyteArray v =
193         static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
194     jarrays[i] = v;
195     jvalues[i] = env->GetByteArrayElements(v, nullptr);
196     cvalues[i] = jvalues[i];
197     lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
198   }
199   TFE_OpSetAttrStringList(op, cname, cvalues.get(), lengths.get(), num_values);
200   for (int i = 0; i < num_values; ++i) {
201     env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
202   }
203   env->ReleaseStringUTFChars(attr_name, cname);
204 }
205 
// Defines the JNI entry point setAttr<name>(op_handle, attr_name, jtype value).
// It forwards `value`, converted to ctype, to TFE_OpSetAttr<name>. Returns
// early if the op handle is zero or the attribute name cannot be converted
// (GetStringUTFChars returns nullptr on failure).
#define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype)                       \
  JNIEXPORT void JNICALL                                                 \
      Java_org_tensorflow_EagerOperationBuilder_setAttr##name(           \
          JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
          jtype value) {                                                 \
    static_assert(                                                       \
        sizeof(ctype) >= sizeof(jtype),                                  \
        "Information loss when converting between Java and C types");    \
    TFE_Op* op = requireOp(env, op_handle);                              \
    if (op == nullptr) return;                                           \
    const char* cname = env->GetStringUTFChars(attr_name, nullptr);      \
    if (cname == nullptr) return;                                        \
    TFE_OpSetAttr##name(op, cname, static_cast<ctype>(value));           \
    env->ReleaseStringUTFChars(attr_name, cname);                        \
  }
220 
// Defines the JNI entry point
// setAttr<name>List(op_handle, attr_name, jtype[] value). The Java array is
// copied element by element into a ctype buffer. That papers over any
// difference in representation between jtype and ctype (for example, jint vs
// TF_DataType); if the copy ever matters in practice it can be avoided for
// many types. Returns early if the op handle is zero or a JNI accessor returns
// nullptr.
#define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)                  \
  JNIEXPORT void JNICALL                                                 \
      Java_org_tensorflow_EagerOperationBuilder_setAttr##name##List(     \
          JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name, \
          jtype##Array value) {                                          \
    TFE_Op* op = requireOp(env, op_handle);                              \
    if (op == nullptr) return;                                           \
    const int n = env->GetArrayLength(value);                            \
    std::unique_ptr<ctype[]> cvalue(new ctype[n]);                       \
    jtype* elems = env->Get##jname##ArrayElements(value, nullptr);       \
    /* nullptr signals failure per the JNI spec; never dereference it. */ \
    if (elems == nullptr) return;                                        \
    for (int i = 0; i < n; ++i) {                                        \
      cvalue[i] = static_cast<ctype>(elems[i]);                          \
    }                                                                    \
    /* Elements were only read, so discard them without copy-back. */    \
    env->Release##jname##ArrayElements(value, elems, JNI_ABORT);         \
    const char* cname = env->GetStringUTFChars(attr_name, nullptr);      \
    if (cname == nullptr) return;                                        \
    TFE_OpSetAttr##name##List(op, cname, cvalue.get(), n);               \
    env->ReleaseStringUTFChars(attr_name, cname);                        \
  }
244 
245 #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
246   DEFINE_SET_ATTR_SCALAR(name, jtype, ctype)       \
247   DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
248 
249 DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
250 DEFINE_SET_ATTR(Float, Float, jfloat, float);
251 DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
252 DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
253 #undef DEFINE_SET_ATTR
254 #undef DEFINE_SET_ATTR_LIST
255 #undef DEFINE_SET_ATTR_SCALAR
256 
Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(JNIEnv * env,jclass clazz,jlong handle,jstring attr_name,jlong tensor_handle)257 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrTensor(
258     JNIEnv* env, jclass clazz, jlong handle, jstring attr_name,
259     jlong tensor_handle) {
260   TFE_Op* op = requireOp(env, handle);
261   if (op == nullptr) return;
262   TF_Tensor* t = requireTensor(env, tensor_handle);
263   if (t == nullptr) return;
264   const char* cname = env->GetStringUTFChars(attr_name, nullptr);
265   TF_Status* status = TF_NewStatus();
266   TFE_OpSetAttrTensor(op, cname, t, status);
267   throwExceptionIfNotOK(env, status);
268   TF_DeleteStatus(status);
269   env->ReleaseStringUTFChars(attr_name, cname);
270 }
271 
Java_org_tensorflow_EagerOperationBuilder_setAttrShape(JNIEnv * env,jclass clazz,jlong op_handle,jstring attr_name,jlongArray shape,jint num_dims)272 JNIEXPORT void JNICALL Java_org_tensorflow_EagerOperationBuilder_setAttrShape(
273     JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
274     jlongArray shape, jint num_dims) {
275   TFE_Op* op = requireOp(env, op_handle);
276   if (op == nullptr) return;
277   std::unique_ptr<int64_t[]> cvalue;
278   // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
279   // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
280   if (num_dims > 0) {
281     cvalue.reset(new int64_t[num_dims]);
282     jlong* elems = env->GetLongArrayElements(shape, nullptr);
283     for (int i = 0; i < num_dims; ++i) {
284       cvalue[i] = static_cast<int64_t>(elems[i]);
285     }
286     env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
287   }
288   const char* cname = env->GetStringUTFChars(attr_name, nullptr);
289   TF_Status* status = TF_NewStatus();
290   TFE_OpSetAttrShape(op, cname, cvalue.get(), static_cast<int>(num_dims),
291                      status);
292   throwExceptionIfNotOK(env, status);
293   TF_DeleteStatus(status);
294   env->ReleaseStringUTFChars(attr_name, cname);
295 }
296 
297 JNIEXPORT void JNICALL
Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(JNIEnv * env,jclass clazz,jlong op_handle,jstring attr_name,jlongArray shapes,jintArray num_dims)298 Java_org_tensorflow_EagerOperationBuilder_setAttrShapeList(
299     JNIEnv* env, jclass clazz, jlong op_handle, jstring attr_name,
300     jlongArray shapes, jintArray num_dims) {
301   TFE_Op* op = requireOp(env, op_handle);
302   if (op == nullptr) return;
303   std::unique_ptr<int64_t[]> cshapes;
304   std::unique_ptr<const int64_t*[]> cdims;
305   std::unique_ptr<int[]> cnum_dims;
306   const int num_dims_length = env->GetArrayLength(num_dims);
307   if (num_dims_length > 0) {
308     const int shapes_length = env->GetArrayLength(shapes);
309     cshapes.reset(new int64_t[shapes_length]);
310     cdims.reset(new const int64_t*[num_dims_length]);
311     cnum_dims.reset(new int[num_dims_length]);
312     jlong* shapes_elems =
313         static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
314     std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
315     env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
316     int64_t* cshapes_ptr = cshapes.get();
317     jint* num_dims_elems =
318         static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
319     for (int i = 0; i < num_dims_length; ++i) {
320       cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
321       cdims[i] = cshapes_ptr;
322       if (cnum_dims[i] > 0) {
323         cshapes_ptr += cnum_dims[i];
324       }
325     }
326     env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
327   }
328   const char* cname = env->GetStringUTFChars(attr_name, nullptr);
329   TF_Status* status = TF_NewStatus();
330   TFE_OpSetAttrShapeList(op, cname, cdims.get(), cnum_dims.get(),
331                          num_dims_length, status);
332   throwExceptionIfNotOK(env, status);
333   TF_DeleteStatus(status);
334   env->ReleaseStringUTFChars(attr_name, cname);
335 }
336