• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/java/src/main/native/operation_builder_jni.h"
17 
18 #include <cstring>
19 #include <memory>
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/java/src/main/native/exception_jni.h"
22 
23 namespace {
requireHandle(JNIEnv * env,jlong handle)24 TF_OperationDescription* requireHandle(JNIEnv* env, jlong handle) {
25   if (handle == 0) {
26     throwException(env, kIllegalStateException,
27                    "Operation has already been built");
28     return nullptr;
29   }
30   return reinterpret_cast<TF_OperationDescription*>(handle);
31 }
32 
resolveOutput(JNIEnv * env,jlong op_handle,jint index,TF_Output * out)33 bool resolveOutput(JNIEnv* env, jlong op_handle, jint index, TF_Output* out) {
34   if (op_handle == 0) {
35     throwException(env, kIllegalStateException,
36                    "close() was called on the Graph");
37     return false;
38   }
39   out->oper = reinterpret_cast<TF_Operation*>(op_handle);
40   out->index = static_cast<int>(index);
41   return true;
42 }
43 
requireTensor(JNIEnv * env,jlong handle)44 TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
45   if (handle == 0) {
46     throwException(env, kIllegalStateException,
47                    "close() has been called on the Tensor");
48     return nullptr;
49   }
50   return reinterpret_cast<TF_Tensor*>(handle);
51 }
52 }  // namespace
53 
Java_org_tensorflow_OperationBuilder_allocate(JNIEnv * env,jclass clazz,jlong graph_handle,jstring type,jstring name)54 JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_allocate(
55     JNIEnv* env, jclass clazz, jlong graph_handle, jstring type, jstring name) {
56   if (graph_handle == 0) {
57     throwException(env, kIllegalStateException,
58                    "close() has been called on the Graph");
59     return 0;
60   }
61   TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
62   const char* op_type = env->GetStringUTFChars(type, nullptr);
63   const char* op_name = env->GetStringUTFChars(name, nullptr);
64   TF_OperationDescription* d = TF_NewOperation(graph, op_type, op_name);
65   env->ReleaseStringUTFChars(name, op_name);
66   env->ReleaseStringUTFChars(type, op_type);
67   static_assert(sizeof(jlong) >= sizeof(TF_OperationDescription*),
68                 "Cannot represent a C TF_OperationDescription as a Java long");
69   return reinterpret_cast<jlong>(d);
70 }
71 
Java_org_tensorflow_OperationBuilder_finish(JNIEnv * env,jclass clazz,jlong handle)72 JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_finish(
73     JNIEnv* env, jclass clazz, jlong handle) {
74   TF_OperationDescription* d = requireHandle(env, handle);
75   if (d == nullptr) return 0;
76   TF_Status* status = TF_NewStatus();
77   TF_Operation* op = TF_FinishOperation(d, status);
78   if (throwExceptionIfNotOK(env, status)) {
79     TF_DeleteStatus(status);
80     return reinterpret_cast<jlong>(op);
81   }
82   TF_DeleteStatus(status);
83   return 0;
84 }
85 
Java_org_tensorflow_OperationBuilder_addInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle,jint index)86 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInput(
87     JNIEnv* env, jclass clazz, jlong handle, jlong op_handle, jint index) {
88   TF_Output out;
89   if (!resolveOutput(env, op_handle, index, &out)) return;
90   TF_OperationDescription* d = requireHandle(env, handle);
91   if (d == nullptr) return;
92   TF_AddInput(d, out);
93 }
94 
Java_org_tensorflow_OperationBuilder_addInputList(JNIEnv * env,jclass clazz,jlong handle,jlongArray op_handles,jintArray indices)95 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList(
96     JNIEnv* env, jclass clazz, jlong handle, jlongArray op_handles,
97     jintArray indices) {
98   TF_OperationDescription* d = requireHandle(env, handle);
99   if (d == nullptr) return;
100   const size_t n = static_cast<size_t>(env->GetArrayLength(op_handles));
101   if (env->GetArrayLength(indices) != n) {
102     throwException(env, kIllegalArgumentException,
103                    "mismatch in number of Operations (%d) and output indices "
104                    "(%d) provided",
105                    n, env->GetArrayLength(indices));
106     return;
107   }
108   std::unique_ptr<TF_Output[]> o(new TF_Output[n]);
109   jlong* oph = env->GetLongArrayElements(op_handles, nullptr);
110   jint* idx = env->GetIntArrayElements(indices, nullptr);
111   bool ok = true;
112   for (int i = 0; i < n && ok; ++i) {
113     ok = resolveOutput(env, oph[i], idx[i], &o[i]);
114   }
115   env->ReleaseIntArrayElements(indices, idx, JNI_ABORT);
116   env->ReleaseLongArrayElements(op_handles, oph, JNI_ABORT);
117   if (!ok) return;
118   TF_AddInputList(d, o.get(), n);
119 }
120 
Java_org_tensorflow_OperationBuilder_addControlInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle)121 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput(
122     JNIEnv* env, jclass clazz, jlong handle, jlong op_handle) {
123   if (op_handle == 0) {
124     throwException(env, kIllegalStateException,
125                    "control input is not valid, "
126                    "perhaps the Graph containing it has been closed()?");
127     return;
128   }
129   TF_Operation* control = reinterpret_cast<TF_Operation*>(op_handle);
130   TF_OperationDescription* d = requireHandle(env, handle);
131   if (d == nullptr) return;
132   TF_AddControlInput(d, control);
133 }
134 
Java_org_tensorflow_OperationBuilder_setDevice(JNIEnv * env,jclass clazz,jlong handle,jstring device)135 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setDevice(
136     JNIEnv* env, jclass clazz, jlong handle, jstring device) {
137   TF_OperationDescription* d = requireHandle(env, handle);
138   if (d == nullptr) return;
139   const char* cdevice = env->GetStringUTFChars(device, nullptr);
140   TF_SetDevice(d, cdevice);
141   env->ReleaseStringUTFChars(device, cdevice);
142 }
143 
Java_org_tensorflow_OperationBuilder_setAttrString(JNIEnv * env,jclass clazz,jlong handle,jstring name,jbyteArray value)144 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrString(
145     JNIEnv* env, jclass clazz, jlong handle, jstring name, jbyteArray value) {
146   static_assert(sizeof(jbyte) == 1,
147                 "Require Java byte to be represented as a single byte");
148   TF_OperationDescription* d = requireHandle(env, handle);
149   if (d == nullptr) return;
150   const char* cname = env->GetStringUTFChars(name, nullptr);
151   jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
152   TF_SetAttrString(d, cname, cvalue, env->GetArrayLength(value));
153   env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
154   env->ReleaseStringUTFChars(name, cname);
155 }
156 
157 #define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype)                           \
158   JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttr##name( \
159       JNIEnv* env, jclass clazz, jlong handle, jstring name, jtype value) {  \
160     static_assert(                                                           \
161         sizeof(ctype) >= sizeof(jtype),                                      \
162         "Information loss when converting between Java and C types");        \
163     TF_OperationDescription* d = requireHandle(env, handle);                 \
164     if (d == nullptr) return;                                                \
165     const char* cname = env->GetStringUTFChars(name, nullptr);               \
166     TF_SetAttr##name(d, cname, static_cast<ctype>(value));                   \
167     env->ReleaseStringUTFChars(name, cname);                                 \
168   }
169 
170 #define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)            \
171   JNIEXPORT void JNICALL                                           \
172       Java_org_tensorflow_OperationBuilder_setAttr##name##List(    \
173           JNIEnv* env, jclass clazz, jlong handle, jstring name,   \
174           jtype##Array value) {                                    \
175     TF_OperationDescription* d = requireHandle(env, handle);       \
176     if (d == nullptr) return;                                      \
177     const char* cname = env->GetStringUTFChars(name, nullptr);     \
178     /* Make a copy of the array to paper over any differences */   \
179     /* in byte representations of the jtype and ctype         */   \
180     /* For example, jint vs TF_DataType.                      */   \
181     /* If this copy turns out to be a problem in practice     */   \
182     /* can avoid it for many types.                           */   \
183     const int n = env->GetArrayLength(value);                      \
184     std::unique_ptr<ctype[]> cvalue(new ctype[n]);                 \
185     jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
186     for (int i = 0; i < n; ++i) {                                  \
187       cvalue[i] = static_cast<ctype>(elems[i]);                    \
188     }                                                              \
189     TF_SetAttr##name##List(d, cname, cvalue.get(), n);             \
190     env->Release##jname##ArrayElements(value, elems, JNI_ABORT);   \
191     env->ReleaseStringUTFChars(name, cname);                       \
192   }
193 
194 #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
195   DEFINE_SET_ATTR_SCALAR(name, jtype, ctype)       \
196   DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
197 
198 DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
199 DEFINE_SET_ATTR(Float, Float, jfloat, float);
200 DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
201 DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
202 #undef DEFINE_SET_ATTR
203 #undef DEFINE_SET_ATTR_LIST
204 #undef DEFINE_SET_ATTR_SCALAR
205 
Java_org_tensorflow_OperationBuilder_setAttrTensor(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlong tensor_handle)206 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensor(
207     JNIEnv* env, jclass clazz, jlong handle, jstring name,
208     jlong tensor_handle) {
209   TF_OperationDescription* d = requireHandle(env, handle);
210   if (d == nullptr) return;
211   TF_Tensor* t = requireTensor(env, tensor_handle);
212   if (t == nullptr) return;
213   const char* cname = env->GetStringUTFChars(name, nullptr);
214   TF_Status* status = TF_NewStatus();
215   TF_SetAttrTensor(d, cname, t, status);
216   throwExceptionIfNotOK(env, status);
217   TF_DeleteStatus(status);
218   env->ReleaseStringUTFChars(name, cname);
219 }
220 
Java_org_tensorflow_OperationBuilder_setAttrTensorList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray tensor_handles)221 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList(
222     JNIEnv* env, jclass clazz, jlong handle, jstring name,
223     jlongArray tensor_handles) {
224   TF_OperationDescription* d = requireHandle(env, handle);
225   if (d == nullptr) return;
226   const int n = env->GetArrayLength(tensor_handles);
227   std::unique_ptr<TF_Tensor* []> tensors(new TF_Tensor*[n]);
228   jlong* jhandles = env->GetLongArrayElements(tensor_handles, nullptr);
229   bool ok = true;
230   for (int i = 0; i < n && ok; ++i) {
231     tensors[i] = requireTensor(env, jhandles[i]);
232     ok = !env->ExceptionCheck();
233   }
234   env->ReleaseLongArrayElements(tensor_handles, jhandles, JNI_ABORT);
235   if (!ok) return;
236 
237   const char* cname = env->GetStringUTFChars(name, nullptr);
238   TF_Status* status = TF_NewStatus();
239   TF_SetAttrTensorList(d, cname, tensors.get(), n, status);
240   throwExceptionIfNotOK(env, status);
241   TF_DeleteStatus(status);
242   env->ReleaseStringUTFChars(name, cname);
243 }
244 
Java_org_tensorflow_OperationBuilder_setAttrShape(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shape,jint num_dims)245 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape(
246     JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shape,
247     jint num_dims) {
248   TF_OperationDescription* d = requireHandle(env, handle);
249   if (d == nullptr) return;
250   std::unique_ptr<int64_t[]> cvalue;
251   // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
252   // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
253   if (num_dims > 0) {
254     cvalue.reset(new int64_t[num_dims]);
255     jlong* elems = env->GetLongArrayElements(shape, nullptr);
256     for (int i = 0; i < num_dims; ++i) {
257       cvalue[i] = static_cast<int64_t>(elems[i]);
258     }
259     env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
260   }
261   const char* cname = env->GetStringUTFChars(name, nullptr);
262   TF_SetAttrShape(d, cname, cvalue.get(), static_cast<int>(num_dims));
263   env->ReleaseStringUTFChars(name, cname);
264 }
265 
Java_org_tensorflow_OperationBuilder_setAttrShapeList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shapes,jintArray num_dims)266 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList(
267     JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes,
268     jintArray num_dims) {
269   TF_OperationDescription* d = requireHandle(env, handle);
270   if (d == nullptr) return;
271   std::unique_ptr<int64_t[]> cshapes;
272   std::unique_ptr<int64_t* []> cdims;
273   std::unique_ptr<int[]> cnum_dims;
274   const int num_dims_length = env->GetArrayLength(num_dims);
275   if (num_dims_length > 0) {
276     const int shapes_length = env->GetArrayLength(shapes);
277     cshapes.reset(new int64_t[shapes_length]);
278     cdims.reset(new int64_t*[num_dims_length]);
279     cnum_dims.reset(new int[num_dims_length]);
280     jlong* shapes_elems =
281         static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
282     std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
283     env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
284     int64_t* cshapes_ptr = cshapes.get();
285     jint* num_dims_elems =
286         static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
287     for (int i = 0; i < num_dims_length; ++i) {
288       cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
289       cdims[i] = cshapes_ptr;
290       if (cnum_dims[i] > 0) {
291         cshapes_ptr += cnum_dims[i];
292       }
293     }
294     env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
295   }
296   const char* cname = env->GetStringUTFChars(name, nullptr);
297   TF_SetAttrShapeList(d, cname, cdims.get(), cnum_dims.get(), num_dims_length);
298   env->ReleaseStringUTFChars(name, cname);
299 }
300 
Java_org_tensorflow_OperationBuilder_setAttrStringList(JNIEnv * env,jclass object,jlong handle,jstring name,jobjectArray values)301 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList(
302     JNIEnv* env, jclass object, jlong handle, jstring name,
303     jobjectArray values) {
304   TF_OperationDescription* d = requireHandle(env, handle);
305   if (d == nullptr) return;
306   const char* cname = env->GetStringUTFChars(name, nullptr);
307   int num_values = env->GetArrayLength(values);
308   static_assert(sizeof(jbyte) == 1,
309                 "Require Java byte to be represented as a single byte");
310   std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
311   std::unique_ptr<jbyte* []> jvalues(new jbyte*[num_values]);
312   std::unique_ptr<void* []> cvalues(new void*[num_values]);
313   std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
314 
315   for (int i = 0; i < num_values; ++i) {
316     jbyteArray v =
317         static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
318     jarrays[i] = v;
319     jvalues[i] = env->GetByteArrayElements(v, nullptr);
320     cvalues[i] = jvalues[i];
321     lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
322   }
323   TF_SetAttrStringList(d, cname, cvalues.get(), lengths.get(), num_values);
324   for (int i = 0; i < num_values; ++i) {
325     env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
326   }
327   env->ReleaseStringUTFChars(name, cname);
328 }
329