1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/graph_operation_builder_jni.h"
17 #include <cstring>
18 #include <memory>
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/java/src/main/native/exception_jni.h"
21
22 namespace {
requireHandle(JNIEnv * env,jlong handle)23 TF_OperationDescription* requireHandle(JNIEnv* env, jlong handle) {
24 if (handle == 0) {
25 throwException(env, kIllegalStateException,
26 "Operation has already been built");
27 return nullptr;
28 }
29 return reinterpret_cast<TF_OperationDescription*>(handle);
30 }
31
resolveOutput(JNIEnv * env,jlong op_handle,jint index,TF_Output * out)32 bool resolveOutput(JNIEnv* env, jlong op_handle, jint index, TF_Output* out) {
33 if (op_handle == 0) {
34 throwException(env, kIllegalStateException,
35 "close() was called on the Graph");
36 return false;
37 }
38 out->oper = reinterpret_cast<TF_Operation*>(op_handle);
39 out->index = static_cast<int>(index);
40 return true;
41 }
42
requireTensor(JNIEnv * env,jlong handle)43 TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
44 if (handle == 0) {
45 throwException(env, kIllegalStateException,
46 "close() has been called on the Tensor");
47 return nullptr;
48 }
49 return reinterpret_cast<TF_Tensor*>(handle);
50 }
51 } // namespace
52
Java_org_tensorflow_GraphOperationBuilder_allocate(JNIEnv * env,jclass clazz,jlong graph_handle,jstring type,jstring name)53 JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_allocate(
54 JNIEnv* env, jclass clazz, jlong graph_handle, jstring type, jstring name) {
55 if (graph_handle == 0) {
56 throwException(env, kIllegalStateException,
57 "close() has been called on the Graph");
58 return 0;
59 }
60 TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
61 const char* op_type = env->GetStringUTFChars(type, nullptr);
62 const char* op_name = env->GetStringUTFChars(name, nullptr);
63 TF_OperationDescription* d = TF_NewOperation(graph, op_type, op_name);
64 env->ReleaseStringUTFChars(name, op_name);
65 env->ReleaseStringUTFChars(type, op_type);
66 static_assert(sizeof(jlong) >= sizeof(TF_OperationDescription*),
67 "Cannot represent a C TF_OperationDescription as a Java long");
68 return reinterpret_cast<jlong>(d);
69 }
70
Java_org_tensorflow_GraphOperationBuilder_finish(JNIEnv * env,jclass clazz,jlong handle)71 JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_finish(
72 JNIEnv* env, jclass clazz, jlong handle) {
73 TF_OperationDescription* d = requireHandle(env, handle);
74 if (d == nullptr) return 0;
75 TF_Status* status = TF_NewStatus();
76 TF_Operation* op = TF_FinishOperation(d, status);
77 if (throwExceptionIfNotOK(env, status)) {
78 TF_DeleteStatus(status);
79 return reinterpret_cast<jlong>(op);
80 }
81 TF_DeleteStatus(status);
82 return 0;
83 }
84
Java_org_tensorflow_GraphOperationBuilder_addInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle,jint index)85 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInput(
86 JNIEnv* env, jclass clazz, jlong handle, jlong op_handle, jint index) {
87 TF_Output out;
88 if (!resolveOutput(env, op_handle, index, &out)) return;
89 TF_OperationDescription* d = requireHandle(env, handle);
90 if (d == nullptr) return;
91 TF_AddInput(d, out);
92 }
93
Java_org_tensorflow_GraphOperationBuilder_addInputList(JNIEnv * env,jclass clazz,jlong handle,jlongArray op_handles,jintArray indices)94 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInputList(
95 JNIEnv* env, jclass clazz, jlong handle, jlongArray op_handles,
96 jintArray indices) {
97 TF_OperationDescription* d = requireHandle(env, handle);
98 if (d == nullptr) return;
99 const size_t n = static_cast<size_t>(env->GetArrayLength(op_handles));
100 if (env->GetArrayLength(indices) != n) {
101 throwException(env, kIllegalArgumentException,
102 "mismatch in number of Operations (%d) and output indices "
103 "(%d) provided",
104 n, env->GetArrayLength(indices));
105 return;
106 }
107 std::unique_ptr<TF_Output[]> o(new TF_Output[n]);
108 jlong* oph = env->GetLongArrayElements(op_handles, nullptr);
109 jint* idx = env->GetIntArrayElements(indices, nullptr);
110 bool ok = true;
111 for (int i = 0; i < n && ok; ++i) {
112 ok = resolveOutput(env, oph[i], idx[i], &o[i]);
113 }
114 env->ReleaseIntArrayElements(indices, idx, JNI_ABORT);
115 env->ReleaseLongArrayElements(op_handles, oph, JNI_ABORT);
116 if (!ok) return;
117 TF_AddInputList(d, o.get(), n);
118 }
119
120 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle)121 Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv* env,
122 jclass clazz,
123 jlong handle,
124 jlong op_handle) {
125 if (op_handle == 0) {
126 throwException(env, kIllegalStateException,
127 "control input is not valid, "
128 "perhaps the Graph containing it has been closed()?");
129 return;
130 }
131 TF_Operation* control = reinterpret_cast<TF_Operation*>(op_handle);
132 TF_OperationDescription* d = requireHandle(env, handle);
133 if (d == nullptr) return;
134 TF_AddControlInput(d, control);
135 }
136
Java_org_tensorflow_GraphOperationBuilder_setDevice(JNIEnv * env,jclass clazz,jlong handle,jstring device)137 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setDevice(
138 JNIEnv* env, jclass clazz, jlong handle, jstring device) {
139 TF_OperationDescription* d = requireHandle(env, handle);
140 if (d == nullptr) return;
141 const char* cdevice = env->GetStringUTFChars(device, nullptr);
142 TF_SetDevice(d, cdevice);
143 env->ReleaseStringUTFChars(device, cdevice);
144 }
145
Java_org_tensorflow_GraphOperationBuilder_setAttrString(JNIEnv * env,jclass clazz,jlong handle,jstring name,jbyteArray value)146 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrString(
147 JNIEnv* env, jclass clazz, jlong handle, jstring name, jbyteArray value) {
148 static_assert(sizeof(jbyte) == 1,
149 "Require Java byte to be represented as a single byte");
150 TF_OperationDescription* d = requireHandle(env, handle);
151 if (d == nullptr) return;
152 const char* cname = env->GetStringUTFChars(name, nullptr);
153 jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
154 TF_SetAttrString(d, cname, cvalue, env->GetArrayLength(value));
155 env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
156 env->ReleaseStringUTFChars(name, cname);
157 }
158
159 #define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
160 JNIEXPORT void JNICALL \
161 Java_org_tensorflow_GraphOperationBuilder_setAttr##name( \
162 JNIEnv* env, jclass clazz, jlong handle, jstring name, \
163 jtype value) { \
164 static_assert( \
165 sizeof(ctype) >= sizeof(jtype), \
166 "Information loss when converting between Java and C types"); \
167 TF_OperationDescription* d = requireHandle(env, handle); \
168 if (d == nullptr) return; \
169 const char* cname = env->GetStringUTFChars(name, nullptr); \
170 TF_SetAttr##name(d, cname, static_cast<ctype>(value)); \
171 env->ReleaseStringUTFChars(name, cname); \
172 }
173
174 #define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype) \
175 JNIEXPORT void JNICALL \
176 Java_org_tensorflow_GraphOperationBuilder_setAttr##name##List( \
177 JNIEnv* env, jclass clazz, jlong handle, jstring name, \
178 jtype##Array value) { \
179 TF_OperationDescription* d = requireHandle(env, handle); \
180 if (d == nullptr) return; \
181 const char* cname = env->GetStringUTFChars(name, nullptr); \
182 /* Make a copy of the array to paper over any differences */ \
183 /* in byte representations of the jtype and ctype */ \
184 /* For example, jint vs TF_DataType. */ \
185 /* If this copy turns out to be a problem in practice */ \
186 /* can avoid it for many types. */ \
187 const int n = env->GetArrayLength(value); \
188 std::unique_ptr<ctype[]> cvalue(new ctype[n]); \
189 jtype* elems = env->Get##jname##ArrayElements(value, nullptr); \
190 for (int i = 0; i < n; ++i) { \
191 cvalue[i] = static_cast<ctype>(elems[i]); \
192 } \
193 TF_SetAttr##name##List(d, cname, cvalue.get(), n); \
194 env->Release##jname##ArrayElements(value, elems, JNI_ABORT); \
195 env->ReleaseStringUTFChars(name, cname); \
196 }
197
198 #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
199 DEFINE_SET_ATTR_SCALAR(name, jtype, ctype) \
200 DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
201
202 DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
203 DEFINE_SET_ATTR(Float, Float, jfloat, float);
204 DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
205 DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
206 #undef DEFINE_SET_ATTR
207 #undef DEFINE_SET_ATTR_LIST
208 #undef DEFINE_SET_ATTR_SCALAR
209
Java_org_tensorflow_GraphOperationBuilder_setAttrTensor(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlong tensor_handle)210 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrTensor(
211 JNIEnv* env, jclass clazz, jlong handle, jstring name,
212 jlong tensor_handle) {
213 TF_OperationDescription* d = requireHandle(env, handle);
214 if (d == nullptr) return;
215 TF_Tensor* t = requireTensor(env, tensor_handle);
216 if (t == nullptr) return;
217 const char* cname = env->GetStringUTFChars(name, nullptr);
218 TF_Status* status = TF_NewStatus();
219 TF_SetAttrTensor(d, cname, t, status);
220 throwExceptionIfNotOK(env, status);
221 TF_DeleteStatus(status);
222 env->ReleaseStringUTFChars(name, cname);
223 }
224
225 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray tensor_handles)226 Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(
227 JNIEnv* env, jclass clazz, jlong handle, jstring name,
228 jlongArray tensor_handles) {
229 TF_OperationDescription* d = requireHandle(env, handle);
230 if (d == nullptr) return;
231 const int n = env->GetArrayLength(tensor_handles);
232 std::unique_ptr<TF_Tensor*[]> tensors(new TF_Tensor*[n]);
233 jlong* jhandles = env->GetLongArrayElements(tensor_handles, nullptr);
234 bool ok = true;
235 for (int i = 0; i < n && ok; ++i) {
236 tensors[i] = requireTensor(env, jhandles[i]);
237 ok = !env->ExceptionCheck();
238 }
239 env->ReleaseLongArrayElements(tensor_handles, jhandles, JNI_ABORT);
240 if (!ok) return;
241
242 const char* cname = env->GetStringUTFChars(name, nullptr);
243 TF_Status* status = TF_NewStatus();
244 TF_SetAttrTensorList(d, cname, tensors.get(), n, status);
245 throwExceptionIfNotOK(env, status);
246 TF_DeleteStatus(status);
247 env->ReleaseStringUTFChars(name, cname);
248 }
249
Java_org_tensorflow_GraphOperationBuilder_setAttrShape(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shape,jint num_dims)250 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrShape(
251 JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shape,
252 jint num_dims) {
253 TF_OperationDescription* d = requireHandle(env, handle);
254 if (d == nullptr) return;
255 std::unique_ptr<int64_t[]> cvalue;
256 // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
257 // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
258 if (num_dims > 0) {
259 cvalue.reset(new int64_t[num_dims]);
260 jlong* elems = env->GetLongArrayElements(shape, nullptr);
261 for (int i = 0; i < num_dims; ++i) {
262 cvalue[i] = static_cast<int64_t>(elems[i]);
263 }
264 env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
265 }
266 const char* cname = env->GetStringUTFChars(name, nullptr);
267 TF_SetAttrShape(d, cname, cvalue.get(), static_cast<int>(num_dims));
268 env->ReleaseStringUTFChars(name, cname);
269 }
270
271 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shapes,jintArray num_dims)272 Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(
273 JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes,
274 jintArray num_dims) {
275 TF_OperationDescription* d = requireHandle(env, handle);
276 if (d == nullptr) return;
277 std::unique_ptr<int64_t[]> cshapes;
278 std::unique_ptr<int64_t*[]> cdims;
279 std::unique_ptr<int[]> cnum_dims;
280 const int num_dims_length = env->GetArrayLength(num_dims);
281 if (num_dims_length > 0) {
282 const int shapes_length = env->GetArrayLength(shapes);
283 cshapes.reset(new int64_t[shapes_length]);
284 cdims.reset(new int64_t*[num_dims_length]);
285 cnum_dims.reset(new int[num_dims_length]);
286 jlong* shapes_elems =
287 static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
288 std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
289 env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
290 int64_t* cshapes_ptr = cshapes.get();
291 jint* num_dims_elems =
292 static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
293 for (int i = 0; i < num_dims_length; ++i) {
294 cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
295 cdims[i] = cshapes_ptr;
296 if (cnum_dims[i] > 0) {
297 cshapes_ptr += cnum_dims[i];
298 }
299 }
300 env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
301 }
302 const char* cname = env->GetStringUTFChars(name, nullptr);
303 TF_SetAttrShapeList(d, cname, cdims.get(), cnum_dims.get(), num_dims_length);
304 env->ReleaseStringUTFChars(name, cname);
305 }
306
307 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(JNIEnv * env,jclass object,jlong handle,jstring name,jobjectArray values)308 Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(
309 JNIEnv* env, jclass object, jlong handle, jstring name,
310 jobjectArray values) {
311 TF_OperationDescription* d = requireHandle(env, handle);
312 if (d == nullptr) return;
313 const char* cname = env->GetStringUTFChars(name, nullptr);
314 int num_values = env->GetArrayLength(values);
315 static_assert(sizeof(jbyte) == 1,
316 "Require Java byte to be represented as a single byte");
317 std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
318 std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]);
319 std::unique_ptr<void*[]> cvalues(new void*[num_values]);
320 std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
321
322 for (int i = 0; i < num_values; ++i) {
323 jbyteArray v =
324 static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
325 jarrays[i] = v;
326 jvalues[i] = env->GetByteArrayElements(v, nullptr);
327 cvalues[i] = jvalues[i];
328 lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
329 }
330 TF_SetAttrStringList(d, cname, cvalues.get(), lengths.get(), num_values);
331 for (int i = 0; i < num_values; ++i) {
332 env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
333 }
334 env->ReleaseStringUTFChars(name, cname);
335 }
336