• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/java/src/main/native/graph_operation_builder_jni.h"
17 #include <cstring>
18 #include <memory>
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/java/src/main/native/exception_jni.h"
21 
22 namespace {
requireHandle(JNIEnv * env,jlong handle)23 TF_OperationDescription* requireHandle(JNIEnv* env, jlong handle) {
24   if (handle == 0) {
25     throwException(env, kIllegalStateException,
26                    "Operation has already been built");
27     return nullptr;
28   }
29   return reinterpret_cast<TF_OperationDescription*>(handle);
30 }
31 
resolveOutput(JNIEnv * env,jlong op_handle,jint index,TF_Output * out)32 bool resolveOutput(JNIEnv* env, jlong op_handle, jint index, TF_Output* out) {
33   if (op_handle == 0) {
34     throwException(env, kIllegalStateException,
35                    "close() was called on the Graph");
36     return false;
37   }
38   out->oper = reinterpret_cast<TF_Operation*>(op_handle);
39   out->index = static_cast<int>(index);
40   return true;
41 }
42 
requireTensor(JNIEnv * env,jlong handle)43 TF_Tensor* requireTensor(JNIEnv* env, jlong handle) {
44   if (handle == 0) {
45     throwException(env, kIllegalStateException,
46                    "close() has been called on the Tensor");
47     return nullptr;
48   }
49   return reinterpret_cast<TF_Tensor*>(handle);
50 }
51 }  // namespace
52 
Java_org_tensorflow_GraphOperationBuilder_allocate(JNIEnv * env,jclass clazz,jlong graph_handle,jstring type,jstring name)53 JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_allocate(
54     JNIEnv* env, jclass clazz, jlong graph_handle, jstring type, jstring name) {
55   if (graph_handle == 0) {
56     throwException(env, kIllegalStateException,
57                    "close() has been called on the Graph");
58     return 0;
59   }
60   TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
61   const char* op_type = env->GetStringUTFChars(type, nullptr);
62   const char* op_name = env->GetStringUTFChars(name, nullptr);
63   TF_OperationDescription* d = TF_NewOperation(graph, op_type, op_name);
64   env->ReleaseStringUTFChars(name, op_name);
65   env->ReleaseStringUTFChars(type, op_type);
66   static_assert(sizeof(jlong) >= sizeof(TF_OperationDescription*),
67                 "Cannot represent a C TF_OperationDescription as a Java long");
68   return reinterpret_cast<jlong>(d);
69 }
70 
Java_org_tensorflow_GraphOperationBuilder_finish(JNIEnv * env,jclass clazz,jlong handle)71 JNIEXPORT jlong JNICALL Java_org_tensorflow_GraphOperationBuilder_finish(
72     JNIEnv* env, jclass clazz, jlong handle) {
73   TF_OperationDescription* d = requireHandle(env, handle);
74   if (d == nullptr) return 0;
75   TF_Status* status = TF_NewStatus();
76   TF_Operation* op = TF_FinishOperation(d, status);
77   if (throwExceptionIfNotOK(env, status)) {
78     TF_DeleteStatus(status);
79     return reinterpret_cast<jlong>(op);
80   }
81   TF_DeleteStatus(status);
82   return 0;
83 }
84 
Java_org_tensorflow_GraphOperationBuilder_addInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle,jint index)85 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInput(
86     JNIEnv* env, jclass clazz, jlong handle, jlong op_handle, jint index) {
87   TF_Output out;
88   if (!resolveOutput(env, op_handle, index, &out)) return;
89   TF_OperationDescription* d = requireHandle(env, handle);
90   if (d == nullptr) return;
91   TF_AddInput(d, out);
92 }
93 
Java_org_tensorflow_GraphOperationBuilder_addInputList(JNIEnv * env,jclass clazz,jlong handle,jlongArray op_handles,jintArray indices)94 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_addInputList(
95     JNIEnv* env, jclass clazz, jlong handle, jlongArray op_handles,
96     jintArray indices) {
97   TF_OperationDescription* d = requireHandle(env, handle);
98   if (d == nullptr) return;
99   const size_t n = static_cast<size_t>(env->GetArrayLength(op_handles));
100   if (env->GetArrayLength(indices) != n) {
101     throwException(env, kIllegalArgumentException,
102                    "mismatch in number of Operations (%d) and output indices "
103                    "(%d) provided",
104                    n, env->GetArrayLength(indices));
105     return;
106   }
107   std::unique_ptr<TF_Output[]> o(new TF_Output[n]);
108   jlong* oph = env->GetLongArrayElements(op_handles, nullptr);
109   jint* idx = env->GetIntArrayElements(indices, nullptr);
110   bool ok = true;
111   for (int i = 0; i < n && ok; ++i) {
112     ok = resolveOutput(env, oph[i], idx[i], &o[i]);
113   }
114   env->ReleaseIntArrayElements(indices, idx, JNI_ABORT);
115   env->ReleaseLongArrayElements(op_handles, oph, JNI_ABORT);
116   if (!ok) return;
117   TF_AddInputList(d, o.get(), n);
118 }
119 
120 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv * env,jclass clazz,jlong handle,jlong op_handle)121 Java_org_tensorflow_GraphOperationBuilder_addControlInput(JNIEnv* env,
122                                                           jclass clazz,
123                                                           jlong handle,
124                                                           jlong op_handle) {
125   if (op_handle == 0) {
126     throwException(env, kIllegalStateException,
127                    "control input is not valid, "
128                    "perhaps the Graph containing it has been closed()?");
129     return;
130   }
131   TF_Operation* control = reinterpret_cast<TF_Operation*>(op_handle);
132   TF_OperationDescription* d = requireHandle(env, handle);
133   if (d == nullptr) return;
134   TF_AddControlInput(d, control);
135 }
136 
Java_org_tensorflow_GraphOperationBuilder_setDevice(JNIEnv * env,jclass clazz,jlong handle,jstring device)137 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setDevice(
138     JNIEnv* env, jclass clazz, jlong handle, jstring device) {
139   TF_OperationDescription* d = requireHandle(env, handle);
140   if (d == nullptr) return;
141   const char* cdevice = env->GetStringUTFChars(device, nullptr);
142   TF_SetDevice(d, cdevice);
143   env->ReleaseStringUTFChars(device, cdevice);
144 }
145 
Java_org_tensorflow_GraphOperationBuilder_setAttrString(JNIEnv * env,jclass clazz,jlong handle,jstring name,jbyteArray value)146 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrString(
147     JNIEnv* env, jclass clazz, jlong handle, jstring name, jbyteArray value) {
148   static_assert(sizeof(jbyte) == 1,
149                 "Require Java byte to be represented as a single byte");
150   TF_OperationDescription* d = requireHandle(env, handle);
151   if (d == nullptr) return;
152   const char* cname = env->GetStringUTFChars(name, nullptr);
153   jbyte* cvalue = env->GetByteArrayElements(value, nullptr);
154   TF_SetAttrString(d, cname, cvalue, env->GetArrayLength(value));
155   env->ReleaseByteArrayElements(value, cvalue, JNI_ABORT);
156   env->ReleaseStringUTFChars(name, cname);
157 }
158 
159 #define DEFINE_SET_ATTR_SCALAR(name, jtype, ctype)                    \
160   JNIEXPORT void JNICALL                                              \
161       Java_org_tensorflow_GraphOperationBuilder_setAttr##name(        \
162           JNIEnv* env, jclass clazz, jlong handle, jstring name,      \
163           jtype value) {                                              \
164     static_assert(                                                    \
165         sizeof(ctype) >= sizeof(jtype),                               \
166         "Information loss when converting between Java and C types"); \
167     TF_OperationDescription* d = requireHandle(env, handle);          \
168     if (d == nullptr) return;                                         \
169     const char* cname = env->GetStringUTFChars(name, nullptr);        \
170     TF_SetAttr##name(d, cname, static_cast<ctype>(value));            \
171     env->ReleaseStringUTFChars(name, cname);                          \
172   }
173 
174 #define DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)              \
175   JNIEXPORT void JNICALL                                             \
176       Java_org_tensorflow_GraphOperationBuilder_setAttr##name##List( \
177           JNIEnv* env, jclass clazz, jlong handle, jstring name,     \
178           jtype##Array value) {                                      \
179     TF_OperationDescription* d = requireHandle(env, handle);         \
180     if (d == nullptr) return;                                        \
181     const char* cname = env->GetStringUTFChars(name, nullptr);       \
182     /* Make a copy of the array to paper over any differences */     \
183     /* in byte representations of the jtype and ctype         */     \
184     /* For example, jint vs TF_DataType.                      */     \
185     /* If this copy turns out to be a problem in practice     */     \
186     /* can avoid it for many types.                           */     \
187     const int n = env->GetArrayLength(value);                        \
188     std::unique_ptr<ctype[]> cvalue(new ctype[n]);                   \
189     jtype* elems = env->Get##jname##ArrayElements(value, nullptr);   \
190     for (int i = 0; i < n; ++i) {                                    \
191       cvalue[i] = static_cast<ctype>(elems[i]);                      \
192     }                                                                \
193     TF_SetAttr##name##List(d, cname, cvalue.get(), n);               \
194     env->Release##jname##ArrayElements(value, elems, JNI_ABORT);     \
195     env->ReleaseStringUTFChars(name, cname);                         \
196   }
197 
198 #define DEFINE_SET_ATTR(name, jname, jtype, ctype) \
199   DEFINE_SET_ATTR_SCALAR(name, jtype, ctype)       \
200   DEFINE_SET_ATTR_LIST(name, jname, jtype, ctype)
201 
202 DEFINE_SET_ATTR(Int, Long, jlong, int64_t);
203 DEFINE_SET_ATTR(Float, Float, jfloat, float);
204 DEFINE_SET_ATTR(Bool, Boolean, jboolean, unsigned char);
205 DEFINE_SET_ATTR(Type, Int, jint, TF_DataType);
206 #undef DEFINE_SET_ATTR
207 #undef DEFINE_SET_ATTR_LIST
208 #undef DEFINE_SET_ATTR_SCALAR
209 
Java_org_tensorflow_GraphOperationBuilder_setAttrTensor(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlong tensor_handle)210 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrTensor(
211     JNIEnv* env, jclass clazz, jlong handle, jstring name,
212     jlong tensor_handle) {
213   TF_OperationDescription* d = requireHandle(env, handle);
214   if (d == nullptr) return;
215   TF_Tensor* t = requireTensor(env, tensor_handle);
216   if (t == nullptr) return;
217   const char* cname = env->GetStringUTFChars(name, nullptr);
218   TF_Status* status = TF_NewStatus();
219   TF_SetAttrTensor(d, cname, t, status);
220   throwExceptionIfNotOK(env, status);
221   TF_DeleteStatus(status);
222   env->ReleaseStringUTFChars(name, cname);
223 }
224 
225 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray tensor_handles)226 Java_org_tensorflow_GraphOperationBuilder_setAttrTensorList(
227     JNIEnv* env, jclass clazz, jlong handle, jstring name,
228     jlongArray tensor_handles) {
229   TF_OperationDescription* d = requireHandle(env, handle);
230   if (d == nullptr) return;
231   const int n = env->GetArrayLength(tensor_handles);
232   std::unique_ptr<TF_Tensor*[]> tensors(new TF_Tensor*[n]);
233   jlong* jhandles = env->GetLongArrayElements(tensor_handles, nullptr);
234   bool ok = true;
235   for (int i = 0; i < n && ok; ++i) {
236     tensors[i] = requireTensor(env, jhandles[i]);
237     ok = !env->ExceptionCheck();
238   }
239   env->ReleaseLongArrayElements(tensor_handles, jhandles, JNI_ABORT);
240   if (!ok) return;
241 
242   const char* cname = env->GetStringUTFChars(name, nullptr);
243   TF_Status* status = TF_NewStatus();
244   TF_SetAttrTensorList(d, cname, tensors.get(), n, status);
245   throwExceptionIfNotOK(env, status);
246   TF_DeleteStatus(status);
247   env->ReleaseStringUTFChars(name, cname);
248 }
249 
Java_org_tensorflow_GraphOperationBuilder_setAttrShape(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shape,jint num_dims)250 JNIEXPORT void JNICALL Java_org_tensorflow_GraphOperationBuilder_setAttrShape(
251     JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shape,
252     jint num_dims) {
253   TF_OperationDescription* d = requireHandle(env, handle);
254   if (d == nullptr) return;
255   std::unique_ptr<int64_t[]> cvalue;
256   // num_dims and env->GetArrayLength(shape) are assumed to be consistent.
257   // i.e., either num_dims < 0 or num_dims == env->GetArrayLength(shape).
258   if (num_dims > 0) {
259     cvalue.reset(new int64_t[num_dims]);
260     jlong* elems = env->GetLongArrayElements(shape, nullptr);
261     for (int i = 0; i < num_dims; ++i) {
262       cvalue[i] = static_cast<int64_t>(elems[i]);
263     }
264     env->ReleaseLongArrayElements(shape, elems, JNI_ABORT);
265   }
266   const char* cname = env->GetStringUTFChars(name, nullptr);
267   TF_SetAttrShape(d, cname, cvalue.get(), static_cast<int>(num_dims));
268   env->ReleaseStringUTFChars(name, cname);
269 }
270 
271 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(JNIEnv * env,jclass clazz,jlong handle,jstring name,jlongArray shapes,jintArray num_dims)272 Java_org_tensorflow_GraphOperationBuilder_setAttrShapeList(
273     JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes,
274     jintArray num_dims) {
275   TF_OperationDescription* d = requireHandle(env, handle);
276   if (d == nullptr) return;
277   std::unique_ptr<int64_t[]> cshapes;
278   std::unique_ptr<int64_t*[]> cdims;
279   std::unique_ptr<int[]> cnum_dims;
280   const int num_dims_length = env->GetArrayLength(num_dims);
281   if (num_dims_length > 0) {
282     const int shapes_length = env->GetArrayLength(shapes);
283     cshapes.reset(new int64_t[shapes_length]);
284     cdims.reset(new int64_t*[num_dims_length]);
285     cnum_dims.reset(new int[num_dims_length]);
286     jlong* shapes_elems =
287         static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
288     std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
289     env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
290     int64_t* cshapes_ptr = cshapes.get();
291     jint* num_dims_elems =
292         static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
293     for (int i = 0; i < num_dims_length; ++i) {
294       cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
295       cdims[i] = cshapes_ptr;
296       if (cnum_dims[i] > 0) {
297         cshapes_ptr += cnum_dims[i];
298       }
299     }
300     env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT);
301   }
302   const char* cname = env->GetStringUTFChars(name, nullptr);
303   TF_SetAttrShapeList(d, cname, cdims.get(), cnum_dims.get(), num_dims_length);
304   env->ReleaseStringUTFChars(name, cname);
305 }
306 
307 JNIEXPORT void JNICALL
Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(JNIEnv * env,jclass object,jlong handle,jstring name,jobjectArray values)308 Java_org_tensorflow_GraphOperationBuilder_setAttrStringList(
309     JNIEnv* env, jclass object, jlong handle, jstring name,
310     jobjectArray values) {
311   TF_OperationDescription* d = requireHandle(env, handle);
312   if (d == nullptr) return;
313   const char* cname = env->GetStringUTFChars(name, nullptr);
314   int num_values = env->GetArrayLength(values);
315   static_assert(sizeof(jbyte) == 1,
316                 "Require Java byte to be represented as a single byte");
317   std::unique_ptr<jbyteArray[]> jarrays(new jbyteArray[num_values]);
318   std::unique_ptr<jbyte*[]> jvalues(new jbyte*[num_values]);
319   std::unique_ptr<void*[]> cvalues(new void*[num_values]);
320   std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
321 
322   for (int i = 0; i < num_values; ++i) {
323     jbyteArray v =
324         static_cast<jbyteArray>(env->GetObjectArrayElement(values, i));
325     jarrays[i] = v;
326     jvalues[i] = env->GetByteArrayElements(v, nullptr);
327     cvalues[i] = jvalues[i];
328     lengths[i] = static_cast<size_t>(env->GetArrayLength(v));
329   }
330   TF_SetAttrStringList(d, cname, cvalues.get(), lengths.get(), num_values);
331   for (int i = 0; i < num_values; ++i) {
332     env->ReleaseByteArrayElements(jarrays[i], jvalues[i], JNI_ABORT);
333   }
334   env->ReleaseStringUTFChars(name, cname);
335 }
336