1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/tensor_jni.h"
17
18 #include <assert.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <algorithm>
22 #include <memory>
23
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/java/src/main/native/exception_jni.h"
26
27 namespace {
28
requireHandle(JNIEnv * env,jlong handle)29 TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
30 if (handle == 0) {
31 throwException(env, kNullPointerException,
32 "close() was called on the Tensor");
33 return nullptr;
34 }
35 return reinterpret_cast<TF_Tensor*>(handle);
36 }
37
elemByteSize(TF_DataType dtype)38 size_t elemByteSize(TF_DataType dtype) {
39 // The code in this file makes the assumption that the
40 // TensorFlow TF_DataTypes and the Java primitive types
41 // have the same byte sizes. Validate that:
42 switch (dtype) {
43 case TF_BOOL:
44 case TF_UINT8:
45 static_assert(sizeof(jboolean) == 1,
46 "Java boolean not compatible with TF_BOOL");
47 static_assert(sizeof(jbyte) == 1,
48 "Java byte not compatible with TF_UINT8");
49 return 1;
50 case TF_FLOAT:
51 case TF_INT32:
52 static_assert(sizeof(jfloat) == 4,
53 "Java float not compatible with TF_FLOAT");
54 static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
55 return 4;
56 case TF_DOUBLE:
57 case TF_INT64:
58 static_assert(sizeof(jdouble) == 8,
59 "Java double not compatible with TF_DOUBLE");
60 static_assert(sizeof(jlong) == 8,
61 "Java long not compatible with TF_INT64");
62 return 8;
63 default:
64 return 0;
65 }
66 }
67
68 // Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
writeScalar(JNIEnv * env,jobject src,TF_DataType dtype,void * dst,size_t dst_size)69 void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
70 size_t dst_size) {
71 size_t sz = elemByteSize(dtype);
72 if (sz != dst_size) {
73 throwException(
74 env, kIllegalStateException,
75 "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
76 dst_size);
77 return;
78 }
79 switch (dtype) {
80 // env->FindClass and env->GetMethodID are expensive and JNI best practices
81 // suggest that they should be cached. However, until the creation of scalar
82 // valued tensors seems to become a noticeable fraction of program execution,
83 // ignore that cost.
84 #define CASE(dtype, jtype, method_name, method_signature, call_type) \
85 case dtype: { \
86 jclass clazz = env->FindClass("java/lang/Number"); \
87 jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
88 jtype v = env->Call##call_type##Method(src, method); \
89 memcpy(dst, &v, sz); \
90 return; \
91 }
92 CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
93 CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
94 CASE(TF_INT32, jint, "intValue", "()I", Int);
95 CASE(TF_INT64, jlong, "longValue", "()J", Long);
96 CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
97 #undef CASE
98 case TF_BOOL: {
99 jclass clazz = env->FindClass("java/lang/Boolean");
100 jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
101 jboolean v = env->CallBooleanMethod(src, method);
102 *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
103 return;
104 }
105 default:
106 throwException(env, kIllegalStateException, "invalid DataType(%d)",
107 dtype);
108 return;
109 }
110 }
111
112 // Copy a 1-D array of Java primitive types to the tensor buffer dst.
113 // Returns the number of bytes written to dst.
write1DArray(JNIEnv * env,jarray array,TF_DataType dtype,void * dst,size_t dst_size)114 size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
115 size_t dst_size) {
116 const int nelems = env->GetArrayLength(array);
117 jboolean is_copy;
118 switch (dtype) {
119 #define CASE(dtype, jtype, get_type) \
120 case dtype: { \
121 jtype##Array a = static_cast<jtype##Array>(array); \
122 jtype* values = env->Get##get_type##ArrayElements(a, &is_copy); \
123 size_t to_copy = nelems * elemByteSize(dtype); \
124 if (to_copy > dst_size) { \
125 throwException( \
126 env, kIllegalStateException, \
127 "cannot write Java array of %d bytes to Tensor of %d bytes", \
128 to_copy, dst_size); \
129 to_copy = 0; \
130 } else { \
131 memcpy(dst, values, to_copy); \
132 } \
133 env->Release##get_type##ArrayElements(a, values, JNI_ABORT); \
134 return to_copy; \
135 }
136 CASE(TF_FLOAT, jfloat, Float);
137 CASE(TF_DOUBLE, jdouble, Double);
138 CASE(TF_INT32, jint, Int);
139 CASE(TF_INT64, jlong, Long);
140 CASE(TF_BOOL, jboolean, Boolean);
141 CASE(TF_UINT8, jbyte, Byte);
142 #undef CASE
143 default:
144 throwException(env, kIllegalStateException, "invalid DataType(%d)",
145 dtype);
146 return 0;
147 }
148 }
149
150 // Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
151 // Java primitive types. Returns the number of bytes read from src.
read1DArray(JNIEnv * env,TF_DataType dtype,const void * src,size_t src_size,jarray dst)152 size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
153 size_t src_size, jarray dst) {
154 const int len = env->GetArrayLength(dst);
155 const size_t sz = len * elemByteSize(dtype);
156 if (sz > src_size) {
157 throwException(
158 env, kIllegalStateException,
159 "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
160 src_size);
161 return 0;
162 }
163 switch (dtype) {
164 #define CASE(dtype, jtype, primitive_type) \
165 case dtype: { \
166 jtype##Array arr = static_cast<jtype##Array>(dst); \
167 env->Set##primitive_type##ArrayRegion(arr, 0, len, \
168 static_cast<const jtype*>(src)); \
169 return sz; \
170 }
171 CASE(TF_FLOAT, jfloat, Float);
172 CASE(TF_DOUBLE, jdouble, Double);
173 CASE(TF_INT32, jint, Int);
174 CASE(TF_INT64, jlong, Long);
175 CASE(TF_BOOL, jboolean, Boolean);
176 CASE(TF_UINT8, jbyte, Byte);
177 #undef CASE
178 default:
179 throwException(env, kIllegalStateException, "invalid DataType(%d)",
180 dtype);
181 }
182 return 0;
183 }
184
writeNDArray(JNIEnv * env,jarray src,TF_DataType dtype,int dims_left,char * dst,size_t dst_size)185 size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
186 char* dst, size_t dst_size) {
187 if (dims_left == 1) {
188 return write1DArray(env, src, dtype, dst, dst_size);
189 } else {
190 jobjectArray ndarray = static_cast<jobjectArray>(src);
191 int len = env->GetArrayLength(ndarray);
192 size_t sz = 0;
193 for (int i = 0; i < len; ++i) {
194 jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
195 sz +=
196 writeNDArray(env, row, dtype, dims_left - 1, dst + sz, dst_size - sz);
197 env->DeleteLocalRef(row);
198 if (env->ExceptionCheck()) return sz;
199 }
200 return sz;
201 }
202 }
203
readNDArray(JNIEnv * env,TF_DataType dtype,const char * src,size_t src_size,int dims_left,jarray dst)204 size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
205 size_t src_size, int dims_left, jarray dst) {
206 if (dims_left == 1) {
207 return read1DArray(env, dtype, src, src_size, dst);
208 } else {
209 jobjectArray ndarray = static_cast<jobjectArray>(dst);
210 int len = env->GetArrayLength(ndarray);
211 size_t sz = 0;
212 for (int i = 0; i < len; ++i) {
213 jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
214 sz +=
215 readNDArray(env, dtype, src + sz, src_size - sz, dims_left - 1, row);
216 env->DeleteLocalRef(row);
217 if (env->ExceptionCheck()) return sz;
218 }
219 return sz;
220 }
221 }
222
TF_StringDecodeTojbyteArray(JNIEnv * env,const TF_TString * src)223 jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const TF_TString* src) {
224 const char* dst = TF_TString_GetDataPointer(src);
225 size_t dst_len = TF_TString_GetSize(src);
226
227 jbyteArray ret = env->NewByteArray(dst_len);
228 jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
229
230 memcpy(cpy, dst, dst_len);
231 env->ReleaseByteArrayElements(ret, cpy, 0);
232 return ret;
233 }
234
235 class StringTensorWriter {
236 public:
StringTensorWriter(TF_Tensor * t,int num_elements)237 StringTensorWriter(TF_Tensor* t, int num_elements)
238 : index_(0), data_(static_cast<TF_TString*>(TF_TensorData(t))) {}
239
Add(const char * src,size_t len,TF_Status * status)240 void Add(const char* src, size_t len, TF_Status* status) {
241 if (TF_GetCode(status) != TF_OK) return;
242 TF_TString_Init(&data_[index_]);
243 TF_TString_Copy(&data_[index_++], src, len);
244 }
245
246 private:
247 int index_;
248 TF_TString* data_;
249 };
250
251 class StringTensorReader {
252 public:
StringTensorReader(const TF_Tensor * t,int num_elements)253 StringTensorReader(const TF_Tensor* t, int num_elements)
254 : index_(0), data_(static_cast<const TF_TString*>(TF_TensorData(t))) {}
255
Next(JNIEnv * env,TF_Status * status)256 jbyteArray Next(JNIEnv* env, TF_Status* status) {
257 if (TF_GetCode(status) != TF_OK) return nullptr;
258 return TF_StringDecodeTojbyteArray(env, &data_[index_++]);
259 }
260
261 private:
262 int index_;
263 const TF_TString* data_;
264 };
265
readNDStringArray(JNIEnv * env,StringTensorReader * reader,int dims_left,jobjectArray dst,TF_Status * status)266 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
267 jobjectArray dst, TF_Status* status) {
268 jsize len = env->GetArrayLength(dst);
269 if (dims_left == 1) {
270 for (jsize i = 0; i < len; ++i) {
271 jbyteArray elem = reader->Next(env, status);
272 if (TF_GetCode(status) != TF_OK) return;
273 env->SetObjectArrayElement(dst, i, elem);
274 env->DeleteLocalRef(elem);
275 }
276 return;
277 }
278 for (jsize i = 0; i < len; ++i) {
279 jobjectArray arr =
280 static_cast<jobjectArray>(env->GetObjectArrayElement(dst, i));
281 readNDStringArray(env, reader, dims_left - 1, arr, status);
282 env->DeleteLocalRef(arr);
283 if (TF_GetCode(status) != TF_OK) return;
284 }
285 }
286 } // namespace
287
Java_org_tensorflow_Tensor_allocate(JNIEnv * env,jclass clazz,jint dtype,jlongArray shape,jlong sizeInBytes)288 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
289 jclass clazz,
290 jint dtype,
291 jlongArray shape,
292 jlong sizeInBytes) {
293 int num_dims = static_cast<int>(env->GetArrayLength(shape));
294 jlong* dims = nullptr;
295 if (num_dims > 0) {
296 jboolean is_copy;
297 dims = env->GetLongArrayElements(shape, &is_copy);
298 }
299 static_assert(sizeof(jlong) == sizeof(int64_t),
300 "Java long is not compatible with the TensorFlow C API");
301 // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
302 //
303 // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
304 // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
305 // *') is not allowed
306 //
307 // Since this array is typically very small, use the guaranteed safe scheme of
308 // creating a copy.
309 int64_t* dims_copy = new int64_t[num_dims];
310 for (int i = 0; i < num_dims; ++i) {
311 dims_copy[i] = static_cast<int64_t>(dims[i]);
312 }
313 TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
314 num_dims, static_cast<size_t>(sizeInBytes));
315 delete[] dims_copy;
316 if (dims != nullptr) {
317 env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
318 }
319 if (t == nullptr) {
320 throwException(env, kNullPointerException,
321 "unable to allocate memory for the Tensor");
322 return 0;
323 }
324 return reinterpret_cast<jlong>(t);
325 }
326
Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv * env,jclass clazz,jbyteArray value)327 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
328 JNIEnv* env, jclass clazz, jbyteArray value) {
329 // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
330 // TF_StringEncode-encoded bytes.
331 size_t src_len = static_cast<int>(env->GetArrayLength(value));
332 TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, sizeof(TF_TString));
333 TF_TString* dst = static_cast<TF_TString*>(TF_TensorData(t));
334
335 TF_Status* status = TF_NewStatus();
336 jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
337 // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
338 // reinterpret_cast<> for this conversion should be safe.
339 TF_TString_Init(&dst[0]);
340 TF_TString_Copy(&dst[0], reinterpret_cast<const char*>(jsrc), src_len);
341
342 env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
343 if (!throwExceptionIfNotOK(env, status)) {
344 TF_DeleteStatus(status);
345 return 0;
346 }
347 TF_DeleteStatus(status);
348 return reinterpret_cast<jlong>(t);
349 }
350
351 namespace {
checkForNullEntries(JNIEnv * env,jarray value,int num_dims)352 void checkForNullEntries(JNIEnv* env, jarray value, int num_dims) {
353 jsize len = env->GetArrayLength(value);
354 for (jsize i = 0; i < len; ++i) {
355 jarray elem = static_cast<jarray>(
356 env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
357 if (elem == nullptr) {
358 throwException(env, kNullPointerException,
359 "null entries in provided array");
360 return;
361 }
362 env->DeleteLocalRef(elem);
363 if (env->ExceptionCheck()) return;
364 }
365 }
366
fillNonScalarTF_STRINGTensorData(JNIEnv * env,jarray value,int num_dims,StringTensorWriter * writer,TF_Status * status)367 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
368 StringTensorWriter* writer,
369 TF_Status* status) {
370 if (num_dims == 0) {
371 jbyte* jsrc =
372 env->GetByteArrayElements(static_cast<jbyteArray>(value), nullptr);
373 writer->Add(reinterpret_cast<const char*>(jsrc), env->GetArrayLength(value),
374 status);
375 env->ReleaseByteArrayElements(static_cast<jbyteArray>(value), jsrc,
376 JNI_ABORT);
377 return;
378 }
379 jsize len = env->GetArrayLength(value);
380 for (jsize i = 0; i < len; ++i) {
381 jarray elem = static_cast<jarray>(
382 env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
383 fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
384 env->DeleteLocalRef(elem);
385 if (TF_GetCode(status) != TF_OK) return;
386 }
387 }
388 } // namespace
389
Java_org_tensorflow_Tensor_allocateNonScalarBytes(JNIEnv * env,jclass clazz,jlongArray shape,jobjectArray value)390 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
391 JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
392 // TF_STRING tensors are encoded with a table of 8-byte offsets following by
393 // TF_StringEncode-encoded bytes.
394 const int num_dims = static_cast<int>(env->GetArrayLength(shape));
395 int64_t* dims = new int64_t[num_dims];
396 int64_t num_elements = 1;
397 {
398 jlong* jdims = env->GetLongArrayElements(shape, nullptr);
399 for (int i = 0; i < num_dims; ++i) {
400 dims[i] = static_cast<int64_t>(jdims[i]);
401 num_elements *= dims[i];
402 }
403 env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
404 }
405 checkForNullEntries(env, value, num_dims);
406 if (env->ExceptionCheck()) return 0;
407 TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
408 sizeof(TF_TString) * num_elements);
409 if (t == nullptr) {
410 delete[] dims;
411 throwException(env, kNullPointerException,
412 "unable to allocate memory for the Tensor");
413 return 0;
414 }
415 TF_Status* status = TF_NewStatus();
416 StringTensorWriter writer(t, num_elements);
417 fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
418 delete[] dims;
419 jlong ret = 0;
420 if (!throwExceptionIfNotOK(env, status)) {
421 TF_DeleteTensor(t);
422 } else {
423 ret = reinterpret_cast<jlong>(t);
424 }
425 TF_DeleteStatus(status);
426 return ret;
427 }
428
Java_org_tensorflow_Tensor_delete(JNIEnv * env,jclass clazz,jlong handle)429 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
430 jclass clazz,
431 jlong handle) {
432 if (handle == 0) return;
433 TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
434 }
435
Java_org_tensorflow_Tensor_buffer(JNIEnv * env,jclass clazz,jlong handle)436 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
437 jclass clazz,
438 jlong handle) {
439 TF_Tensor* t = requireHandle(env, handle);
440 if (t == nullptr) return nullptr;
441 void* data = TF_TensorData(t);
442 const size_t sz = TF_TensorByteSize(t);
443
444 return env->NewDirectByteBuffer(data, static_cast<jlong>(sz));
445 }
446
Java_org_tensorflow_Tensor_dtype(JNIEnv * env,jclass clazz,jlong handle)447 JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
448 jclass clazz,
449 jlong handle) {
450 static_assert(sizeof(jint) >= sizeof(TF_DataType),
451 "TF_DataType in C cannot be represented as an int in Java");
452 TF_Tensor* t = requireHandle(env, handle);
453 if (t == nullptr) return 0;
454 return static_cast<jint>(TF_TensorType(t));
455 }
456
Java_org_tensorflow_Tensor_shape(JNIEnv * env,jclass clazz,jlong handle)457 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
458 jclass clazz,
459 jlong handle) {
460 TF_Tensor* t = requireHandle(env, handle);
461 if (t == nullptr) return nullptr;
462 static_assert(sizeof(jlong) == sizeof(int64_t),
463 "Java long is not compatible with the TensorFlow C API");
464 const jsize num_dims = TF_NumDims(t);
465 jlongArray ret = env->NewLongArray(num_dims);
466 jlong* dims = env->GetLongArrayElements(ret, nullptr);
467 for (int i = 0; i < num_dims; ++i) {
468 dims[i] = static_cast<jlong>(TF_Dim(t, i));
469 }
470 env->ReleaseLongArrayElements(ret, dims, 0);
471 return ret;
472 }
473
Java_org_tensorflow_Tensor_setValue(JNIEnv * env,jclass clazz,jlong handle,jobject value)474 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
475 jclass clazz,
476 jlong handle,
477 jobject value) {
478 TF_Tensor* t = requireHandle(env, handle);
479 if (t == nullptr) return;
480 int num_dims = TF_NumDims(t);
481 TF_DataType dtype = TF_TensorType(t);
482 void* data = TF_TensorData(t);
483 const size_t sz = TF_TensorByteSize(t);
484 if (num_dims == 0) {
485 writeScalar(env, value, dtype, data, sz);
486 } else {
487 writeNDArray(env, static_cast<jarray>(value), dtype, num_dims,
488 static_cast<char*>(data), sz);
489 }
490 }
491
492 #define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix) \
493 JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix( \
494 JNIEnv* env, jclass clazz, jlong handle) { \
495 jtype ret = 0; \
496 TF_Tensor* t = requireHandle(env, handle); \
497 if (t == nullptr) return ret; \
498 if (TF_NumDims(t) != 0) { \
499 throwException(env, kIllegalStateException, "Tensor is not a scalar"); \
500 } else if (TF_TensorType(t) != dtype) { \
501 throwException(env, kIllegalStateException, "Tensor is not a %s scalar", \
502 #method_suffix); \
503 } else { \
504 memcpy(&ret, TF_TensorData(t), elemByteSize(dtype)); \
505 } \
506 return ret; \
507 }
508 DEFINE_GET_SCALAR_METHOD(jfloat, TF_FLOAT, Float);
509 DEFINE_GET_SCALAR_METHOD(jdouble, TF_DOUBLE, Double);
510 DEFINE_GET_SCALAR_METHOD(jint, TF_INT32, Int);
511 DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
512 DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
513 #undef DEFINE_GET_SCALAR_METHOD
514
Java_org_tensorflow_Tensor_scalarBytes(JNIEnv * env,jclass clazz,jlong handle)515 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
516 JNIEnv* env, jclass clazz, jlong handle) {
517 TF_Tensor* t = requireHandle(env, handle);
518 if (t == nullptr) return nullptr;
519 if (TF_NumDims(t) != 0) {
520 throwException(env, kIllegalStateException, "Tensor is not a scalar");
521 return nullptr;
522 }
523 if (TF_TensorType(t) != TF_STRING) {
524 throwException(env, kIllegalArgumentException,
525 "Tensor is not a string/bytes scalar");
526 return nullptr;
527 }
528 const TF_TString* data = static_cast<const TF_TString*>(TF_TensorData(t));
529 jbyteArray ret = TF_StringDecodeTojbyteArray(env, &data[0]);
530 return ret;
531 }
532
Java_org_tensorflow_Tensor_readNDArray(JNIEnv * env,jclass clazz,jlong handle,jobject value)533 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
534 jclass clazz,
535 jlong handle,
536 jobject value) {
537 TF_Tensor* t = requireHandle(env, handle);
538 if (t == nullptr) return;
539 int num_dims = TF_NumDims(t);
540 TF_DataType dtype = TF_TensorType(t);
541 const void* data = TF_TensorData(t);
542 const size_t sz = TF_TensorByteSize(t);
543 if (num_dims == 0) {
544 throwException(env, kIllegalArgumentException,
545 "copyTo() is not meant for scalar Tensors, use the scalar "
546 "accessor (floatValue(), intValue() etc.) instead");
547 return;
548 }
549 if (dtype == TF_STRING) {
550 int64_t num_elements = 1;
551 for (int i = 0; i < num_dims; ++i) {
552 num_elements *= TF_Dim(t, i);
553 }
554 StringTensorReader reader(t, num_elements);
555 TF_Status* status = TF_NewStatus();
556 readNDStringArray(env, &reader, num_dims, static_cast<jobjectArray>(value),
557 status);
558 throwExceptionIfNotOK(env, status);
559 TF_DeleteStatus(status);
560 return;
561 }
562 readNDArray(env, dtype, static_cast<const char*>(data), sz, num_dims,
563 static_cast<jarray>(value));
564 }
565