1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/operation_builder_jni.h"
17
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
22
23 namespace {
requireHandle(JNIEnv * env,jlong handle)24 TF_OperationDescription* requireHandle(JNIEnv* env, jlong handle) {
25 if (handle == 0) {
26 throwException(env, kIllegalStateException,
27 "Operation has already been built");
28 return nullptr;
29 }
30 return reinterpret_cast<TF_OperationDescription*>(handle);
31 }
32
resolveOutput(JNIEnv * env,jlong op_handle,jint index,TF_Output * out)33 bool resolveOutput(JNIEnv* env, jlong op_handle, jint index, TF_Output* out) {
34 if (op_handle == 0) {
35 throwException(env, kIllegalStateException,
36 "close() was called on the Graph");
37 return false;
38 }
39 out->oper = reinterpret_cast<TF_Operation*>(op_handle);
40 out->index = static_cast<int>(index);
41 return true;
42 }
43
requireTensor(JNIEnv * env,jlong handle)44 TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
45 if (handle == 0) {
46 throwException(env, kIllegalStateException,
47 "close() has been called on the Tensor");
48 return nullptr;
49 }
50 return reinterpret_cast<TF_Tensor*>(handle);
51 }
52 } // namespace
53
Java_org_tensorflow_OperationBuilder_allocate(JNIEnv * env,jclass clazz,jlong graph_handle,jstring type,jstring name)54 JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_allocate(
55 JNIEnv* env, jclass clazz, jlong graph_handle, jstring type, jstring name) {
56 if (graph_handle == 0) {
57 throwException(env, kIllegalStateException,
58 "close() has been called on the Graph");
59 return 0;
60 }
61 TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
62 const char* op_type = env->GetStringUTFChars(type, nullptr);
63 const char* op_name = env->GetStringUTFChars(name, nullptr);
64 TF_OperationDescription* d = TF_NewOperation(graph, op_type, op_name);
65 env->ReleaseStringUTFChars(name, op_name);
66 env->ReleaseStringUTFChars(type, op_type);
67 static_assert(sizeof(jlong) >= sizeof(TF_OperationDescription*),
68 "Cannot represent a C TF_OperationDescription as a Java long");
69 return reinterpret_cast<jlong>(d);
70 }
71
Java_org_tensorflow_OperationBuilder_finish(JNIEnv * env,jclass clazz,jlong handle)72 JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_finish(
73 JNIEnv* env, jclass clazz, jlong handle) {
74 TF_OperationDescription* d = requireHandle(env, handle);
75 if (d == nullptr) return 0;
76 TF_Status* status = TF_NewStatus();
77 TF_Operation* op = TF_FinishOperation(d, status);
78 if (throwExceptionIfNotOK(env, status)) {
79 TF_DeleteStatus(status);
80 return reinterpret_cast<jlong>(op);
81 }
82 TF_DeleteStatus(status);
83 return 0;
84 }
85
Java_org_tensorflow_OperationBuilder_addInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle,jint index)86 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInput(
87 JNIEnv* env, jclass clazz, jlong handle, jlong op_handle, jint index) {
88 TF_Output out;
89 if (!resolveOutput(env, op_handle, index, &out)) return;
90 TF_OperationDescription* d = requireHandle(env, handle);
91 if (d == nullptr) return;
92 TF_AddInput(d, out);
93 }
94
Java_org_tensorflow_OperationBuilder_addInputList(JNIEnv * env,jclass clazz,jlong handle,jlongArray op_handles,jintArray indices)95 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList(
96 JNIEnv* env, jclass clazz, jlong handle, jlongArray op_handles,
97 jintArray indices) {
98 TF_OperationDescription* d = requireHandle(env, handle);
99 if (d == nullptr) return;
100 const size_t n = static_cast<size_t>(env->GetArrayLength(op_handles));
101 if (env->GetArrayLength(indices) != n) {
102 throwException(env, kIllegalArgumentException,
103 "mismatch in number of Operations (%d) and output indices "
104 "(%d) provided",
105 n, env->GetArrayLength(indices));
106 return;
107 }
108 std::unique_ptr<TF_Output[]> o(new TF_Output[n]);
109 jlong* oph = env->GetLongArrayElements(op_handles, nullptr);
110 jint* idx = env->GetIntArrayElements(indices, nullptr);
111 bool ok = true;
112 for (int i = 0; i < n && ok; ++i) {
113 ok = resolveOutput(env, oph[i], idx[i], &o[i]);
114 }
115 env->ReleaseIntArrayElements(indices, idx, JNI_ABORT);
116 env->ReleaseLongArrayElements(op_handles, oph, JNI_ABORT);
117 if (!ok) return;
118 TF_AddInputList(d, o.get(), n);
119 }
120
Java_org_tensorflow_OperationBuilder_addControlInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle)121 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput(
122 JNIEnv* env, jclass clazz, jlong handle, jlong op_handle) {
123 if (op_handle == 0) {
124 throwException(env, kIllegalStateException,
125 "control input is not valid, "
126 "perhaps the Graph containing it has been closed()?");
127 return;
128 }
129 TF_Operation* control = reinterpret_cast<TF_Operation*>(op_handle);
130 TF_OperationDescription* d = requireHandle(env, handle);
131 if (d == nullptr) return;
132 TF_AddControlInput(d, control);
133 }
134
Java_org_tensorflow_OperationBuilder_setDevice(JNIEnv * env,jclass clazz,jlong handle,jstring device)135 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setDevice(
136 JNIEnv* env, jclass clazz, jlong handle, jstring device) {
137 TF_OperationDescription* d = requireHandle(env, handle);
138 if (d == nullptr) return;
139 const char* cdevice = env->GetStringUTFChars(device, nullptr);
140 TF_SetDevice(d, cdevice);
141 env->ReleaseStringUTFChars(device, cdevice);
142 }
143
Java_org_tensorflow_OperationBuilder_setAttrString(JNIEnv * env,jclass clazz,jlong handle,jstring name,jbyteArray value)144 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrString(
145 JNIEnv* env, jclass clazz, jlong handle, jstring name, jbyteArray value) {
146 static_assert(sizeof(jbyte) == 1,
147 "Require Java byte to be represented as a single byte");
148 TF_OperationDescription* d = requireHandle(env, handle);
149 if (d == nullptr) return;
150 const char* cname = env->GetStringUTFChars(name, nullptr);
151 jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
152 TF_SetAttrString(d, cname, cvalue, env->GetArrayLength(value));
153 env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
154 env->ReleaseStringUTFChars(name, cname);
155 }
156
157 #define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
158 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttr##name( \
159 JNIEnv* env, jclass clazz, jlong handle, jstring name, jtype value) { \
160 static_assert( \
161 sizeof(ctype) >= sizeof(jtype), \
162 "Information loss when converting between Java and C types"); \
163 TF_OperationDescription* d = requireHandle(env, handle); \
164 if (d == nullptr) return; \
165 const char* cname = env->GetStringUTFChars(name, nullptr); \
166 TF_SetAttr##name(d, cname, static_cast<ctype>(value)); \
167 env->ReleaseStringUTFChars(name, cname); \
168 }
169
170 #define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
171 JNIEXPORT void JNICALL \
172 Java_org_tensorflow_OperationBuilder_setAttr##name##List( \
173 JNIEnv* env, jclass clazz, jlong handle, jstring name, \
174 jtype##Array value) { \
175 TF_OperationDescription* d = requireHandle(env, handle); \
176 if (d == nullptr) return; \
177 const char* cname = env->GetStringUTFChars(name, nullptr); \
178 /* Make a copy of the array to paper over any differences */ \
179 /* in byte representations of the jtype and ctype */ \
180 /* For example, jint vs TF_DataType. */ \
181 /* If this copy turns out to be a problem in practice */ \
182 /* can avoid it for many types. */ \
183 const int n = env->GetArrayLength(value); \
184 std::unique_ptr<ctype[]> cvalue(new ctype[n]); \
185 jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
186 for (int i = 0; i < n; ++i) { \
187 cvalue[i] = static_cast<ctype>(elems[i]); \
188 } \
189 TF_SetAttr##name##List(d, cname, cvalue.get(), n); \
190 env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
191 env->ReleaseStringUTFChars(name, cname); \
192 }
193
194 #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
195 DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
196 DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
197
198 DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
199 DEFINE_SET_ATTR(Float, Float, jfloat, float);
200 DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
201 DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
202 #undef DEFINE_SET_ATTR
203 #undef DEFINE_SET_ATTR_LIST
204 #undef DEFINE_SET_ATTR_SCALAR
205
Java_org_tensorflow_OperationBuilder_setAttrTensor(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlong tensor_handle)206 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensor(
207 JNIEnv* env, jclass clazz, jlong handle, jstring name,
208 jlong tensor_handle) {
209 TF_OperationDescription* d = requireHandle(env, handle);
210 if (d == nullptr) return;
211 TF_Tensor* t = requireTensor(env, tensor_handle);
212 if (t == nullptr) return;
213 const char* cname = env->GetStringUTFChars(name, nullptr);
214 TF_Status* status = TF_NewStatus();
215 TF_SetAttrTensor(d, cname, t, status);
216 throwExceptionIfNotOK(env, status);
217 TF_DeleteStatus(status);
218 env->ReleaseStringUTFChars(name, cname);
219 }
220
Java_org_tensorflow_OperationBuilder_setAttrTensorList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray tensor_handles)221 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList(
222 JNIEnv* env, jclass clazz, jlong handle, jstring name,
223 jlongArray tensor_handles) {
224 TF_OperationDescription* d = requireHandle(env, handle);
225 if (d == nullptr) return;
226 const int n = env->GetArrayLength(tensor_handles);
227 std::unique_ptr<TF_Tensor* []> tensors(new TF_Tensor*[n]);
228 jlong* jhandles = env->GetLongArrayElements(tensor_handles, nullptr);
229 bool ok = true;
230 for (int i = 0; i < n && ok; ++i) {
231 tensors[i] = requireTensor(env, jhandles[i]);
232 ok = !env->ExceptionCheck();
233 }
234 env->ReleaseLongArrayElements(tensor_handles, jhandles, JNI_ABORT);
235 if (!ok) return;
236
237 const char* cname = env->GetStringUTFChars(name, nullptr);
238 TF_Status* status = TF_NewStatus();
239 TF_SetAttrTensorList(d, cname, tensors.get(), n, status);
240 throwExceptionIfNotOK(env, status);
241 TF_DeleteStatus(status);
242 env->ReleaseStringUTFChars(name, cname);
243 }
244
Java_org_tensorflow_OperationBuilder_setAttrShape(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shape,jint num_dims)245 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape(
246 JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shape,
247 jint num_dims) {
248 TF_OperationDescription* d = requireHandle(env, handle);
249 if (d == nullptr) return;
250 std::unique_ptr<int64_t[]> cvalue;
251 // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
252 // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
253 if (num_dims > 0) {
254 cvalue.reset(new int64_t[num_dims]);
255 jlong* elems = env->GetLongArrayElements(shape, nullptr);
256 for (int i = 0; i < num_dims; ++i) {
257 cvalue[i] = static_cast<int64_t>(elems[i]);
258 }
259 env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
260 }
261 const char* cname = env->GetStringUTFChars(name, nullptr);
262 TF_SetAttrShape(d, cname, cvalue.get(), static_cast<int>(num_dims));
263 env->ReleaseStringUTFChars(name, cname);
264 }
265
Java_org_tensorflow_OperationBuilder_setAttrShapeList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shapes,jintArray num_dims)266 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList(
267 JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes,
268 jintArray num_dims) {
269 TF_OperationDescription* d = requireHandle(env, handle);
270 if (d == nullptr) return;
271 std::unique_ptr<int64_t[]> cshapes;
272 std::unique_ptr<int64_t* []> cdims;
273 std::unique_ptr<int[]> cnum_dims;
274 const int num_dims_length = env->GetArrayLength(num_dims);
275 if (num_dims_length > 0) {
276 const int shapes_length = env->GetArrayLength(shapes);
277 cshapes.reset(new int64_t[shapes_length]);
278 cdims.reset(new int64_t*[num_dims_length]);
279 cnum_dims.reset(new int[num_dims_length]);
280 jlong* shapes_elems =
281 static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
282 std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
283 env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
284 int64_t* cshapes_ptr = cshapes.get();
285 jint* num_dims_elems =
286 static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
287 for (int i = 0; i < num_dims_length; ++i) {
288 cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
289 cdims[i] = cshapes_ptr;
290 if (cnum_dims[i] > 0) {
291 cshapes_ptr += cnum_dims[i];
292 }
293 }
294 env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
295 }
296 const char* cname = env->GetStringUTFChars(name, nullptr);
297 TF_SetAttrShapeList(d, cname, cdims.get(), cnum_dims.get(), num_dims_length);
298 env->ReleaseStringUTFChars(name, cname);
299 }
300
Java_org_tensorflow_OperationBuilder_setAttrStringList(JNIEnv * env,jclass object,jlong handle,jstring name,jobjectArray values)301 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList(
302 JNIEnv* env, jclass object, jlong handle, jstring name,
303 jobjectArray values) {
304 TF_OperationDescription* d = requireHandle(env, handle);
305 if (d == nullptr) return;
306 const char* cname = env->GetStringUTFChars(name, nullptr);
307 int num_values = env->GetArrayLength(values);
308 static_assert(sizeof(jbyte) == 1,
309 "Require Java byte to be represented as a single byte");
310 std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
311 std::unique_ptr<jbyte* []> jvalues(new jbyte*[num_values]);
312 std::unique_ptr<void* []> cvalues(new void*[num_values]);
313 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
314
315 for (int i = 0; i < num_values; ++i) {
316 jbyteArray v =
317 static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
318 jarrays[i] = v;
319 jvalues[i] = env->GetByteArrayElements(v, nullptr);
320 cvalues[i] = jvalues[i];
321 lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
322 }
323 TF_SetAttrStringList(d, cname, cvalues.get(), lengths.get(), num_values);
324 for (int i = 0; i < num_values; ++i) {
325 env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
326 }
327 env->ReleaseStringUTFChars(name, cname);
328 }
329