1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/tensor_jni.h"
17
18 #include <assert.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <algorithm>
22 #include <memory>
23
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/java/src/main/native/exception_jni.h"
26
27 namespace {
28
requireHandle(JNIEnv * env,jlong handle)29 TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
30 if (handle == 0) {
31 throwException(env, kNullPointerException,
32 "close() was called on the Tensor");
33 return nullptr;
34 }
35 return reinterpret_cast<TF_Tensor*>(handle);
36 }
37
elemByteSize(TF_DataType dtype)38 size_t elemByteSize(TF_DataType dtype) {
39 // The code in this file makes the assumption that the
40 // TensorFlow TF_DataTypes and the Java primitive types
41 // have the same byte sizes. Validate that:
42 switch (dtype) {
43 case TF_BOOL:
44 case TF_UINT8:
45 static_assert(sizeof(jboolean) == 1,
46 "Java boolean not compatible with TF_BOOL");
47 static_assert(sizeof(jbyte) == 1,
48 "Java byte not compatible with TF_UINT8");
49 return 1;
50 case TF_FLOAT:
51 case TF_INT32:
52 static_assert(sizeof(jfloat) == 4,
53 "Java float not compatible with TF_FLOAT");
54 static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
55 return 4;
56 case TF_DOUBLE:
57 case TF_INT64:
58 static_assert(sizeof(jdouble) == 8,
59 "Java double not compatible with TF_DOUBLE");
60 static_assert(sizeof(jlong) == 8,
61 "Java long not compatible with TF_INT64");
62 return 8;
63 default:
64 return 0;
65 }
66 }
67
68 // Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
writeScalar(JNIEnv * env,jobject src,TF_DataType dtype,void * dst,size_t dst_size)69 void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
70 size_t dst_size) {
71 size_t sz = elemByteSize(dtype);
72 if (sz != dst_size) {
73 throwException(
74 env, kIllegalStateException,
75 "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
76 dst_size);
77 return;
78 }
79 switch (dtype) {
80 // env->FindClass and env->GetMethodID are expensive and JNI best practices
81 // suggest that they should be cached. However, until the creation of scalar
82 // valued tensors seems to become a noticeable fraction of program execution,
83 // ignore that cost.
84 #define CASE(dtype, jtype, method_name, method_signature, call_type) \
85 case dtype: { \
86 jclass clazz = env->FindClass("java/lang/Number"); \
87 jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
88 jtype v = env->Call##call_type##Method(src, method); \
89 memcpy(dst, &v, sz); \
90 return; \
91 }
92 CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
93 CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
94 CASE(TF_INT32, jint, "intValue", "()I", Int);
95 CASE(TF_INT64, jlong, "longValue", "()J", Long);
96 CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
97 #undef CASE
98 case TF_BOOL: {
99 jclass clazz = env->FindClass("java/lang/Boolean");
100 jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
101 jboolean v = env->CallBooleanMethod(src, method);
102 *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
103 return;
104 }
105 default:
106 throwException(env, kIllegalStateException, "invalid DataType(%d)",
107 dtype);
108 return;
109 }
110 }
111
112 // Copy a 1-D array of Java primitive types to the tensor buffer dst.
113 // Returns the number of bytes written to dst.
write1DArray(JNIEnv * env,jarray array,TF_DataType dtype,void * dst,size_t dst_size)114 size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
115 size_t dst_size) {
116 const int nelems = env->GetArrayLength(array);
117 jboolean is_copy;
118 switch (dtype) {
119 #define CASE(dtype, jtype, get_type) \
120 case dtype: { \
121 jtype##Array a = static_cast<jtype##Array>(array); \
122 jtype* values = env->Get##get_type##ArrayElements(a, &is_copy); \
123 size_t to_copy = nelems * elemByteSize(dtype); \
124 if (to_copy > dst_size) { \
125 throwException( \
126 env, kIllegalStateException, \
127 "cannot write Java array of %d bytes to Tensor of %d bytes", \
128 to_copy, dst_size); \
129 to_copy = 0; \
130 } else { \
131 memcpy(dst, values, to_copy); \
132 } \
133 env->Release##get_type##ArrayElements(a, values, JNI_ABORT); \
134 return to_copy; \
135 }
136 CASE(TF_FLOAT, jfloat, Float);
137 CASE(TF_DOUBLE, jdouble, Double);
138 CASE(TF_INT32, jint, Int);
139 CASE(TF_INT64, jlong, Long);
140 CASE(TF_BOOL, jboolean, Boolean);
141 CASE(TF_UINT8, jbyte, Byte);
142 #undef CASE
143 default:
144 throwException(env, kIllegalStateException, "invalid DataType(%d)",
145 dtype);
146 return 0;
147 }
148 }
149
150 // Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
151 // Java primitive types. Returns the number of bytes read from src.
read1DArray(JNIEnv * env,TF_DataType dtype,const void * src,size_t src_size,jarray dst)152 size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
153 size_t src_size, jarray dst) {
154 const int len = env->GetArrayLength(dst);
155 const size_t sz = len * elemByteSize(dtype);
156 if (sz > src_size) {
157 throwException(
158 env, kIllegalStateException,
159 "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
160 src_size);
161 return 0;
162 }
163 switch (dtype) {
164 #define CASE(dtype, jtype, primitive_type) \
165 case dtype: { \
166 jtype##Array arr = static_cast<jtype##Array>(dst); \
167 env->Set##primitive_type##ArrayRegion(arr, 0, len, \
168 static_cast<const jtype*>(src)); \
169 return sz; \
170 }
171 CASE(TF_FLOAT, jfloat, Float);
172 CASE(TF_DOUBLE, jdouble, Double);
173 CASE(TF_INT32, jint, Int);
174 CASE(TF_INT64, jlong, Long);
175 CASE(TF_BOOL, jboolean, Boolean);
176 CASE(TF_UINT8, jbyte, Byte);
177 #undef CASE
178 default:
179 throwException(env, kIllegalStateException, "invalid DataType(%d)",
180 dtype);
181 }
182 return 0;
183 }
184
writeNDArray(JNIEnv * env,jarray src,TF_DataType dtype,int dims_left,char * dst,size_t dst_size)185 size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
186 char* dst, size_t dst_size) {
187 if (dims_left == 1) {
188 return write1DArray(env, src, dtype, dst, dst_size);
189 } else {
190 jobjectArray ndarray = static_cast<jobjectArray>(src);
191 int len = env->GetArrayLength(ndarray);
192 size_t sz = 0;
193 for (int i = 0; i < len; ++i) {
194 jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
195 sz +=
196 writeNDArray(env, row, dtype, dims_left - 1, dst + sz, dst_size - sz);
197 env->DeleteLocalRef(row);
198 if (env->ExceptionCheck()) return sz;
199 }
200 return sz;
201 }
202 }
203
readNDArray(JNIEnv * env,TF_DataType dtype,const char * src,size_t src_size,int dims_left,jarray dst)204 size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
205 size_t src_size, int dims_left, jarray dst) {
206 if (dims_left == 1) {
207 return read1DArray(env, dtype, src, src_size, dst);
208 } else {
209 jobjectArray ndarray = static_cast<jobjectArray>(dst);
210 int len = env->GetArrayLength(ndarray);
211 size_t sz = 0;
212 for (int i = 0; i < len; ++i) {
213 jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
214 sz +=
215 readNDArray(env, dtype, src + sz, src_size - sz, dims_left - 1, row);
216 env->DeleteLocalRef(row);
217 if (env->ExceptionCheck()) return sz;
218 }
219 return sz;
220 }
221 }
222
TF_StringDecodeTojbyteArray(JNIEnv * env,const char * src,size_t src_len,TF_Status * status)223 jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
224 size_t src_len, TF_Status* status) {
225 const char* dst = nullptr;
226 size_t dst_len = 0;
227 TF_StringDecode(src, src_len, &dst, &dst_len, status);
228 if (TF_GetCode(status) != TF_OK) {
229 return nullptr;
230 }
231 jbyteArray ret = env->NewByteArray(dst_len);
232 jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
233 memcpy(cpy, dst, dst_len);
234 env->ReleaseByteArrayElements(ret, cpy, 0);
235 return ret;
236 }
237
238 class StringTensorWriter {
239 public:
StringTensorWriter(TF_Tensor * t,int num_elements)240 StringTensorWriter(TF_Tensor* t, int num_elements)
241 : offset_(0),
242 poffsets_(static_cast<char*>(TF_TensorData(t))),
243 pdata_(poffsets_ + 8 * num_elements),
244 plimit_(poffsets_ + TF_TensorByteSize(t)) {}
245
Add(const char * src,size_t len,TF_Status * status)246 void Add(const char* src, size_t len, TF_Status* status) {
247 if (TF_GetCode(status) != TF_OK) return;
248 if (plimit_ - poffsets_ < sizeof(offset_)) {
249 TF_SetStatus(status, TF_OUT_OF_RANGE,
250 "TF_STRING tensor encoding ran out of space for offsets, "
251 "this is likely a bug, please file an issue at "
252 "https://github.com/tensorflow/tensorflow/issues/new");
253 return;
254 }
255 memcpy(poffsets_, &offset_, sizeof(offset_));
256 size_t written =
257 TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
258 offset_ += written;
259 poffsets_ += 8;
260 pdata_ += written;
261 }
262
263 private:
264 uint64_t offset_;
265 char* poffsets_;
266 char* pdata_;
267 const char* plimit_;
268 };
269
270 class StringTensorReader {
271 public:
StringTensorReader(const TF_Tensor * t,int num_elements)272 StringTensorReader(const TF_Tensor* t, int num_elements)
273 : index_(0),
274 offsets_(static_cast<const char*>(TF_TensorData(t))),
275 data_(offsets_ + 8 * num_elements),
276 limit_(offsets_ + TF_TensorByteSize(t)) {}
277
Next(JNIEnv * env,TF_Status * status)278 jbyteArray Next(JNIEnv* env, TF_Status* status) {
279 if (TF_GetCode(status) != TF_OK) return nullptr;
280 uint64_t offset = 0;
281 const char* poffset = offsets_ + sizeof(offset) * index_;
282 if (poffset >= limit_) {
283 TF_SetStatus(
284 status, TF_INTERNAL,
285 "Invalid TF_STRING tensor, offsets table seems to be too small");
286 return nullptr;
287 }
288 memcpy(&offset, poffset, sizeof(offset));
289 const char* pdata = data_ + offset;
290 if (pdata >= limit_) {
291 TF_SetStatus(status, TF_INTERNAL,
292 "Invalid TF_STRING tensor, invalid entry in offset table");
293 return nullptr;
294 }
295 ++index_;
296 return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
297 }
298
299 private:
300 int index_;
301 const char* offsets_;
302 const char* data_;
303 const char* limit_;
304 };
305
readNDStringArray(JNIEnv * env,StringTensorReader * reader,int dims_left,jobjectArray dst,TF_Status * status)306 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
307 jobjectArray dst, TF_Status* status) {
308 jsize len = env->GetArrayLength(dst);
309 if (dims_left == 1) {
310 for (jsize i = 0; i < len; ++i) {
311 jbyteArray elem = reader->Next(env, status);
312 if (TF_GetCode(status) != TF_OK) return;
313 env->SetObjectArrayElement(dst, i, elem);
314 }
315 return;
316 }
317 for (jsize i = 0; i < len; ++i) {
318 jobjectArray arr =
319 static_cast<jobjectArray>(env->GetObjectArrayElement(dst, i));
320 readNDStringArray(env, reader, dims_left - 1, arr, status);
321 if (TF_GetCode(status) != TF_OK) return;
322 }
323 }
324 } // namespace
325
Java_org_tensorflow_Tensor_allocate(JNIEnv * env,jclass clazz,jint dtype,jlongArray shape,jlong sizeInBytes)326 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
327 jclass clazz,
328 jint dtype,
329 jlongArray shape,
330 jlong sizeInBytes) {
331 int num_dims = static_cast<int>(env->GetArrayLength(shape));
332 jlong* dims = nullptr;
333 if (num_dims > 0) {
334 jboolean is_copy;
335 dims = env->GetLongArrayElements(shape, &is_copy);
336 }
337 static_assert(sizeof(jlong) == sizeof(int64_t),
338 "Java long is not compatible with the TensorFlow C API");
339 // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
340 //
341 // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
342 // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
343 // *') is not allowed
344 //
345 // Since this array is typically very small, use the guaranteed safe scheme of
346 // creating a copy.
347 int64_t* dims_copy = new int64_t[num_dims];
348 for (int i = 0; i < num_dims; ++i) {
349 dims_copy[i] = static_cast<int64_t>(dims[i]);
350 }
351 TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
352 num_dims, static_cast<size_t>(sizeInBytes));
353 delete[] dims_copy;
354 if (dims != nullptr) {
355 env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
356 }
357 if (t == nullptr) {
358 throwException(env, kNullPointerException,
359 "unable to allocate memory for the Tensor");
360 return 0;
361 }
362 return reinterpret_cast<jlong>(t);
363 }
364
Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv * env,jclass clazz,jbyteArray value)365 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
366 JNIEnv* env, jclass clazz, jbyteArray value) {
367 // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
368 // TF_StringEncode-encoded bytes.
369 size_t src_len = static_cast<int>(env->GetArrayLength(value));
370 size_t dst_len = TF_StringEncodedSize(src_len);
371 TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
372 char* dst = static_cast<char*>(TF_TensorData(t));
373 memset(dst, 0, 8); // The offset table
374
375 TF_Status* status = TF_NewStatus();
376 jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
377 // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
378 // reinterpret_cast<> for this conversion should be safe.
379 TF_StringEncode(reinterpret_cast<const char*>(jsrc), src_len, dst + 8,
380 dst_len, status);
381 env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
382 if (!throwExceptionIfNotOK(env, status)) {
383 TF_DeleteStatus(status);
384 return 0;
385 }
386 TF_DeleteStatus(status);
387 return reinterpret_cast<jlong>(t);
388 }
389
390 namespace {
nonScalarTF_STRINGTensorSize(JNIEnv * env,jarray value,int num_dims)391 size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
392 if (num_dims == 0) {
393 // This is the last dimension, i.e., value should correspond to a jbyteArray
394 // encoding the string.
395 return TF_StringEncodedSize(
396 static_cast<size_t>(env->GetArrayLength(value)));
397 }
398 jsize len = env->GetArrayLength(value);
399 size_t ret = 0;
400 for (jsize i = 0; i < len; ++i) {
401 jarray elem = static_cast<jarray>(
402 env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
403 if (elem == nullptr) {
404 throwException(env, kNullPointerException,
405 "null entries in provided array");
406 return ret;
407 }
408 ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
409 if (env->ExceptionCheck()) return ret;
410 }
411 return ret;
412 }
413
fillNonScalarTF_STRINGTensorData(JNIEnv * env,jarray value,int num_dims,StringTensorWriter * writer,TF_Status * status)414 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
415 StringTensorWriter* writer,
416 TF_Status* status) {
417 if (num_dims == 0) {
418 jbyte* jsrc =
419 env->GetByteArrayElements(static_cast<jbyteArray>(value), nullptr);
420 writer->Add(reinterpret_cast<const char*>(jsrc), env->GetArrayLength(value),
421 status);
422 env->ReleaseByteArrayElements(static_cast<jbyteArray>(value), jsrc,
423 JNI_ABORT);
424 return;
425 }
426 jsize len = env->GetArrayLength(value);
427 for (jsize i = 0; i < len; ++i) {
428 jarray elem = static_cast<jarray>(
429 env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
430 fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
431 if (TF_GetCode(status) != TF_OK) return;
432 }
433 }
434 } // namespace
435
Java_org_tensorflow_Tensor_allocateNonScalarBytes(JNIEnv * env,jclass clazz,jlongArray shape,jobjectArray value)436 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
437 JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
438 // TF_STRING tensors are encoded with a table of 8-byte offsets following by
439 // TF_StringEncode-encoded bytes.
440 const int num_dims = static_cast<int>(env->GetArrayLength(shape));
441 int64_t* dims = new int64_t[num_dims];
442 int64_t num_elements = 1;
443 {
444 jlong* jdims = env->GetLongArrayElements(shape, nullptr);
445 for (int i = 0; i < num_dims; ++i) {
446 dims[i] = static_cast<int64_t>(jdims[i]);
447 num_elements *= dims[i];
448 }
449 env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
450 }
451 const size_t encoded_size =
452 nonScalarTF_STRINGTensorSize(env, value, num_dims);
453 if (env->ExceptionCheck()) return 0;
454 TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
455 8 * num_elements + encoded_size);
456 if (t == nullptr) {
457 delete[] dims;
458 throwException(env, kNullPointerException,
459 "unable to allocate memory for the Tensor");
460 return 0;
461 }
462 TF_Status* status = TF_NewStatus();
463 StringTensorWriter writer(t, num_elements);
464 fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
465 delete[] dims;
466 jlong ret = 0;
467 if (!throwExceptionIfNotOK(env, status)) {
468 TF_DeleteTensor(t);
469 } else {
470 ret = reinterpret_cast<jlong>(t);
471 }
472 TF_DeleteStatus(status);
473 return ret;
474 }
475
Java_org_tensorflow_Tensor_delete(JNIEnv * env,jclass clazz,jlong handle)476 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
477 jclass clazz,
478 jlong handle) {
479 if (handle == 0) return;
480 TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
481 }
482
Java_org_tensorflow_Tensor_buffer(JNIEnv * env,jclass clazz,jlong handle)483 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
484 jclass clazz,
485 jlong handle) {
486 TF_Tensor* t = requireHandle(env, handle);
487 if (t == nullptr) return nullptr;
488 void* data = TF_TensorData(t);
489 const size_t sz = TF_TensorByteSize(t);
490
491 return env->NewDirectByteBuffer(data, static_cast<jlong>(sz));
492 }
493
Java_org_tensorflow_Tensor_dtype(JNIEnv * env,jclass clazz,jlong handle)494 JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
495 jclass clazz,
496 jlong handle) {
497 static_assert(sizeof(jint) >= sizeof(TF_DataType),
498 "TF_DataType in C cannot be represented as an int in Java");
499 TF_Tensor* t = requireHandle(env, handle);
500 if (t == nullptr) return 0;
501 return static_cast<jint>(TF_TensorType(t));
502 }
503
Java_org_tensorflow_Tensor_shape(JNIEnv * env,jclass clazz,jlong handle)504 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
505 jclass clazz,
506 jlong handle) {
507 TF_Tensor* t = requireHandle(env, handle);
508 if (t == nullptr) return nullptr;
509 static_assert(sizeof(jlong) == sizeof(int64_t),
510 "Java long is not compatible with the TensorFlow C API");
511 const jsize num_dims = TF_NumDims(t);
512 jlongArray ret = env->NewLongArray(num_dims);
513 jlong* dims = env->GetLongArrayElements(ret, nullptr);
514 for (int i = 0; i < num_dims; ++i) {
515 dims[i] = static_cast<jlong>(TF_Dim(t, i));
516 }
517 env->ReleaseLongArrayElements(ret, dims, 0);
518 return ret;
519 }
520
Java_org_tensorflow_Tensor_setValue(JNIEnv * env,jclass clazz,jlong handle,jobject value)521 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
522 jclass clazz,
523 jlong handle,
524 jobject value) {
525 TF_Tensor* t = requireHandle(env, handle);
526 if (t == nullptr) return;
527 int num_dims = TF_NumDims(t);
528 TF_DataType dtype = TF_TensorType(t);
529 void* data = TF_TensorData(t);
530 const size_t sz = TF_TensorByteSize(t);
531 if (num_dims == 0) {
532 writeScalar(env, value, dtype, data, sz);
533 } else {
534 writeNDArray(env, static_cast<jarray>(value), dtype, num_dims,
535 static_cast<char*>(data), sz);
536 }
537 }
538
539 #define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix) \
540 JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix( \
541 JNIEnv* env, jclass clazz, jlong handle) { \
542 jtype ret = 0; \
543 TF_Tensor* t = requireHandle(env, handle); \
544 if (t == nullptr) return ret; \
545 if (TF_NumDims(t) != 0) { \
546 throwException(env, kIllegalStateException, "Tensor is not a scalar"); \
547 } else if (TF_TensorType(t) != dtype) { \
548 throwException(env, kIllegalStateException, "Tensor is not a %s scalar", \
549 #method_suffix); \
550 } else { \
551 memcpy(&ret, TF_TensorData(t), elemByteSize(dtype)); \
552 } \
553 return ret; \
554 }
555 DEFINE_GET_SCALAR_METHOD(jfloat, TF_FLOAT, Float);
556 DEFINE_GET_SCALAR_METHOD(jdouble, TF_DOUBLE, Double);
557 DEFINE_GET_SCALAR_METHOD(jint, TF_INT32, Int);
558 DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
559 DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
560 #undef DEFINE_GET_SCALAR_METHOD
561
Java_org_tensorflow_Tensor_scalarBytes(JNIEnv * env,jclass clazz,jlong handle)562 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
563 JNIEnv* env, jclass clazz, jlong handle) {
564 TF_Tensor* t = requireHandle(env, handle);
565 if (t == nullptr) return nullptr;
566 if (TF_NumDims(t) != 0) {
567 throwException(env, kIllegalStateException, "Tensor is not a scalar");
568 return nullptr;
569 }
570 if (TF_TensorType(t) != TF_STRING) {
571 throwException(env, kIllegalArgumentException,
572 "Tensor is not a string/bytes scalar");
573 return nullptr;
574 }
575 const char* data = static_cast<const char*>(TF_TensorData(t));
576 const char* src = data + 8;
577 size_t src_len = TF_TensorByteSize(t) - 8;
578 uint64_t offset = 0;
579 memcpy(&offset, data, sizeof(offset));
580 if (offset >= src_len) {
581 throwException(env, kIllegalArgumentException,
582 "invalid tensor encoding: bad offsets");
583 return nullptr;
584 }
585 TF_Status* status = TF_NewStatus();
586 jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
587 throwExceptionIfNotOK(env, status);
588 TF_DeleteStatus(status);
589 return ret;
590 }
591
Java_org_tensorflow_Tensor_readNDArray(JNIEnv * env,jclass clazz,jlong handle,jobject value)592 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
593 jclass clazz,
594 jlong handle,
595 jobject value) {
596 TF_Tensor* t = requireHandle(env, handle);
597 if (t == nullptr) return;
598 int num_dims = TF_NumDims(t);
599 TF_DataType dtype = TF_TensorType(t);
600 const void* data = TF_TensorData(t);
601 const size_t sz = TF_TensorByteSize(t);
602 if (num_dims == 0) {
603 throwException(env, kIllegalArgumentException,
604 "copyTo() is not meant for scalar Tensors, use the scalar "
605 "accessor (floatValue(), intValue() etc.) instead");
606 return;
607 }
608 if (dtype == TF_STRING) {
609 int64_t num_elements = 1;
610 for (int i = 0; i < num_dims; ++i) {
611 num_elements *= TF_Dim(t, i);
612 }
613 StringTensorReader reader(t, num_elements);
614 TF_Status* status = TF_NewStatus();
615 readNDStringArray(env, &reader, num_dims, static_cast<jobjectArray>(value),
616 status);
617 throwExceptionIfNotOK(env, status);
618 TF_DeleteStatus(status);
619 return;
620 }
621 readNDArray(env, dtype, static_cast<const char*>(data), sz, num_dims,
622 static_cast<jarray>(value));
623 }
624