• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/java/src/main/native/tensor_jni.h"
17 
18 #include <assert.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <algorithm>
22 #include <memory>
23 
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/java/src/main/native/exception_jni.h"
26 
27 namespace {
28 
requireHandle(JNIEnv * env,jlong handle)29 TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
30   if (handle == 0) {
31     throwException(env, kNullPointerException,
32                    "close() was called on the Tensor");
33     return nullptr;
34   }
35   return reinterpret_cast<TF_Tensor*>(handle);
36 }
37 
elemByteSize(TF_DataType dtype)38 size_t elemByteSize(TF_DataType dtype) {
39   // The code in this file makes the assumption that the
40   // TensorFlow TF_DataTypes and the Java primitive types
41   // have the same byte sizes. Validate that:
42   switch (dtype) {
43     case TF_BOOL:
44     case TF_UINT8:
45       static_assert(sizeof(jboolean) == 1,
46                     "Java boolean not compatible with TF_BOOL");
47       static_assert(sizeof(jbyte) == 1,
48                     "Java byte not compatible with TF_UINT8");
49       return 1;
50     case TF_FLOAT:
51     case TF_INT32:
52       static_assert(sizeof(jfloat) == 4,
53                     "Java float not compatible with TF_FLOAT");
54       static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
55       return 4;
56     case TF_DOUBLE:
57     case TF_INT64:
58       static_assert(sizeof(jdouble) == 8,
59                     "Java double not compatible with TF_DOUBLE");
60       static_assert(sizeof(jlong) == 8,
61                     "Java long not compatible with TF_INT64");
62       return 8;
63     default:
64       return 0;
65   }
66 }
67 
68 // Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
writeScalar(JNIEnv * env,jobject src,TF_DataType dtype,void * dst,size_t dst_size)69 void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
70                  size_t dst_size) {
71   size_t sz = elemByteSize(dtype);
72   if (sz != dst_size) {
73     throwException(
74         env, kIllegalStateException,
75         "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
76         dst_size);
77     return;
78   }
79   switch (dtype) {
80 // env->FindClass and env->GetMethodID are expensive and JNI best practices
81 // suggest that they should be cached. However, until the creation of scalar
82 // valued tensors seems to become a noticeable fraction of program execution,
83 // ignore that cost.
84 #define CASE(dtype, jtype, method_name, method_signature, call_type)           \
85   case dtype: {                                                                \
86     jclass clazz = env->FindClass("java/lang/Number");                         \
87     jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
88     jtype v = env->Call##call_type##Method(src, method);                       \
89     memcpy(dst, &v, sz);                                                       \
90     return;                                                                    \
91   }
92     CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
93     CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
94     CASE(TF_INT32, jint, "intValue", "()I", Int);
95     CASE(TF_INT64, jlong, "longValue", "()J", Long);
96     CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
97 #undef CASE
98     case TF_BOOL: {
99       jclass clazz = env->FindClass("java/lang/Boolean");
100       jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
101       jboolean v = env->CallBooleanMethod(src, method);
102       *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
103       return;
104     }
105     default:
106       throwException(env, kIllegalStateException, "invalid DataType(%d)",
107                      dtype);
108       return;
109   }
110 }
111 
112 // Copy a 1-D array of Java primitive types to the tensor buffer dst.
113 // Returns the number of bytes written to dst.
write1DArray(JNIEnv * env,jarray array,TF_DataType dtype,void * dst,size_t dst_size)114 size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
115                     size_t dst_size) {
116   const int nelems = env->GetArrayLength(array);
117   jboolean is_copy;
118   switch (dtype) {
119 #define CASE(dtype, jtype, get_type)                                   \
120   case dtype: {                                                        \
121     jtype##Array a = static_cast<jtype##Array>(array);                 \
122     jtype* values = env->Get##get_type##ArrayElements(a, &is_copy);    \
123     size_t to_copy = nelems * elemByteSize(dtype);                     \
124     if (to_copy > dst_size) {                                          \
125       throwException(                                                  \
126           env, kIllegalStateException,                                 \
127           "cannot write Java array of %d bytes to Tensor of %d bytes", \
128           to_copy, dst_size);                                          \
129       to_copy = 0;                                                     \
130     } else {                                                           \
131       memcpy(dst, values, to_copy);                                    \
132     }                                                                  \
133     env->Release##get_type##ArrayElements(a, values, JNI_ABORT);       \
134     return to_copy;                                                    \
135   }
136     CASE(TF_FLOAT, jfloat, Float);
137     CASE(TF_DOUBLE, jdouble, Double);
138     CASE(TF_INT32, jint, Int);
139     CASE(TF_INT64, jlong, Long);
140     CASE(TF_BOOL, jboolean, Boolean);
141     CASE(TF_UINT8, jbyte, Byte);
142 #undef CASE
143     default:
144       throwException(env, kIllegalStateException, "invalid DataType(%d)",
145                      dtype);
146       return 0;
147   }
148 }
149 
150 // Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
151 // Java primitive types. Returns the number of bytes read from src.
read1DArray(JNIEnv * env,TF_DataType dtype,const void * src,size_t src_size,jarray dst)152 size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
153                    size_t src_size, jarray dst) {
154   const int len = env->GetArrayLength(dst);
155   const size_t sz = len * elemByteSize(dtype);
156   if (sz > src_size) {
157     throwException(
158         env, kIllegalStateException,
159         "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
160         src_size);
161     return 0;
162   }
163   switch (dtype) {
164 #define CASE(dtype, jtype, primitive_type)                                 \
165   case dtype: {                                                            \
166     jtype##Array arr = static_cast<jtype##Array>(dst);                     \
167     env->Set##primitive_type##ArrayRegion(arr, 0, len,                     \
168                                           static_cast<const jtype*>(src)); \
169     return sz;                                                             \
170   }
171     CASE(TF_FLOAT, jfloat, Float);
172     CASE(TF_DOUBLE, jdouble, Double);
173     CASE(TF_INT32, jint, Int);
174     CASE(TF_INT64, jlong, Long);
175     CASE(TF_BOOL, jboolean, Boolean);
176     CASE(TF_UINT8, jbyte, Byte);
177 #undef CASE
178     default:
179       throwException(env, kIllegalStateException, "invalid DataType(%d)",
180                      dtype);
181   }
182   return 0;
183 }
184 
writeNDArray(JNIEnv * env,jarray src,TF_DataType dtype,int dims_left,char * dst,size_t dst_size)185 size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
186                     char* dst, size_t dst_size) {
187   if (dims_left == 1) {
188     return write1DArray(env, src, dtype, dst, dst_size);
189   } else {
190     jobjectArray ndarray = static_cast<jobjectArray>(src);
191     int len = env->GetArrayLength(ndarray);
192     size_t sz = 0;
193     for (int i = 0; i < len; ++i) {
194       jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
195       sz +=
196           writeNDArray(env, row, dtype, dims_left - 1, dst + sz, dst_size - sz);
197       env->DeleteLocalRef(row);
198       if (env->ExceptionCheck()) return sz;
199     }
200     return sz;
201   }
202 }
203 
readNDArray(JNIEnv * env,TF_DataType dtype,const char * src,size_t src_size,int dims_left,jarray dst)204 size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
205                    size_t src_size, int dims_left, jarray dst) {
206   if (dims_left == 1) {
207     return read1DArray(env, dtype, src, src_size, dst);
208   } else {
209     jobjectArray ndarray = static_cast<jobjectArray>(dst);
210     int len = env->GetArrayLength(ndarray);
211     size_t sz = 0;
212     for (int i = 0; i < len; ++i) {
213       jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
214       sz +=
215           readNDArray(env, dtype, src + sz, src_size - sz, dims_left - 1, row);
216       env->DeleteLocalRef(row);
217       if (env->ExceptionCheck()) return sz;
218     }
219     return sz;
220   }
221 }
222 
TF_StringDecodeTojbyteArray(JNIEnv * env,const TF_TString * src)223 jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const TF_TString* src) {
224   const char* dst = TF_TString_GetDataPointer(src);
225   size_t dst_len = TF_TString_GetSize(src);
226 
227   jbyteArray ret = env->NewByteArray(dst_len);
228   jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
229 
230   memcpy(cpy, dst, dst_len);
231   env->ReleaseByteArrayElements(ret, cpy, 0);
232   return ret;
233 }
234 
235 class StringTensorWriter {
236  public:
StringTensorWriter(TF_Tensor * t,int num_elements)237   StringTensorWriter(TF_Tensor* t, int num_elements)
238       : index_(0), data_(static_cast<TF_TString*>(TF_TensorData(t))) {}
239 
Add(const char * src,size_t len,TF_Status * status)240   void Add(const char* src, size_t len, TF_Status* status) {
241     if (TF_GetCode(status) != TF_OK) return;
242     TF_TString_Init(&data_[index_]);
243     TF_TString_Copy(&data_[index_++], src, len);
244   }
245 
246  private:
247   int index_;
248   TF_TString* data_;
249 };
250 
251 class StringTensorReader {
252  public:
StringTensorReader(const TF_Tensor * t,int num_elements)253   StringTensorReader(const TF_Tensor* t, int num_elements)
254       : index_(0), data_(static_cast<const TF_TString*>(TF_TensorData(t))) {}
255 
Next(JNIEnv * env,TF_Status * status)256   jbyteArray Next(JNIEnv* env, TF_Status* status) {
257     if (TF_GetCode(status) != TF_OK) return nullptr;
258     return TF_StringDecodeTojbyteArray(env, &data_[index_++]);
259   }
260 
261  private:
262   int index_;
263   const TF_TString* data_;
264 };
265 
readNDStringArray(JNIEnv * env,StringTensorReader * reader,int dims_left,jobjectArray dst,TF_Status * status)266 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
267                        jobjectArray dst, TF_Status* status) {
268   jsize len = env->GetArrayLength(dst);
269   if (dims_left == 1) {
270     for (jsize i = 0; i < len; ++i) {
271       jbyteArray elem = reader->Next(env, status);
272       if (TF_GetCode(status) != TF_OK) return;
273       env->SetObjectArrayElement(dst, i, elem);
274       env->DeleteLocalRef(elem);
275     }
276     return;
277   }
278   for (jsize i = 0; i < len; ++i) {
279     jobjectArray arr =
280         static_cast<jobjectArray>(env->GetObjectArrayElement(dst, i));
281     readNDStringArray(env, reader, dims_left - 1, arr, status);
282     env->DeleteLocalRef(arr);
283     if (TF_GetCode(status) != TF_OK) return;
284   }
285 }
286 }  // namespace
287 
Java_org_tensorflow_Tensor_allocate(JNIEnv * env,jclass clazz,jint dtype,jlongArray shape,jlong sizeInBytes)288 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
289                                                             jclass clazz,
290                                                             jint dtype,
291                                                             jlongArray shape,
292                                                             jlong sizeInBytes) {
293   int num_dims = static_cast<int>(env->GetArrayLength(shape));
294   jlong* dims = nullptr;
295   if (num_dims > 0) {
296     jboolean is_copy;
297     dims = env->GetLongArrayElements(shape, &is_copy);
298   }
299   static_assert(sizeof(jlong) == sizeof(int64_t),
300                 "Java long is not compatible with the TensorFlow C API");
301   // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
302   //
303   // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
304   // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
305   // *') is not allowed
306   //
307   // Since this array is typically very small, use the guaranteed safe scheme of
308   // creating a copy.
309   int64_t* dims_copy = new int64_t[num_dims];
310   for (int i = 0; i < num_dims; ++i) {
311     dims_copy[i] = static_cast<int64_t>(dims[i]);
312   }
313   TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
314                                    num_dims, static_cast<size_t>(sizeInBytes));
315   delete[] dims_copy;
316   if (dims != nullptr) {
317     env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
318   }
319   if (t == nullptr) {
320     throwException(env, kNullPointerException,
321                    "unable to allocate memory for the Tensor");
322     return 0;
323   }
324   return reinterpret_cast<jlong>(t);
325 }
326 
Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv * env,jclass clazz,jbyteArray value)327 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
328     JNIEnv* env, jclass clazz, jbyteArray value) {
329   // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
330   // TF_StringEncode-encoded bytes.
331   size_t src_len = static_cast<int>(env->GetArrayLength(value));
332   TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, sizeof(TF_TString));
333   TF_TString* dst = static_cast<TF_TString*>(TF_TensorData(t));
334 
335   TF_Status* status = TF_NewStatus();
336   jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
337   // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
338   // reinterpret_cast<> for this conversion should be safe.
339   TF_TString_Init(&dst[0]);
340   TF_TString_Copy(&dst[0], reinterpret_cast<const char*>(jsrc), src_len);
341 
342   env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
343   if (!throwExceptionIfNotOK(env, status)) {
344     TF_DeleteStatus(status);
345     return 0;
346   }
347   TF_DeleteStatus(status);
348   return reinterpret_cast<jlong>(t);
349 }
350 
351 namespace {
checkForNullEntries(JNIEnv * env,jarray value,int num_dims)352 void checkForNullEntries(JNIEnv* env, jarray value, int num_dims) {
353   jsize len = env->GetArrayLength(value);
354   for (jsize i = 0; i < len; ++i) {
355     jarray elem = static_cast<jarray>(
356         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
357     if (elem == nullptr) {
358       throwException(env, kNullPointerException,
359                      "null entries in provided array");
360       return;
361     }
362     env->DeleteLocalRef(elem);
363     if (env->ExceptionCheck()) return;
364   }
365 }
366 
fillNonScalarTF_STRINGTensorData(JNIEnv * env,jarray value,int num_dims,StringTensorWriter * writer,TF_Status * status)367 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
368                                       StringTensorWriter* writer,
369                                       TF_Status* status) {
370   if (num_dims == 0) {
371     jbyte* jsrc =
372         env->GetByteArrayElements(static_cast<jbyteArray>(value), nullptr);
373     writer->Add(reinterpret_cast<const char*>(jsrc), env->GetArrayLength(value),
374                 status);
375     env->ReleaseByteArrayElements(static_cast<jbyteArray>(value), jsrc,
376                                   JNI_ABORT);
377     return;
378   }
379   jsize len = env->GetArrayLength(value);
380   for (jsize i = 0; i < len; ++i) {
381     jarray elem = static_cast<jarray>(
382         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
383     fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
384     env->DeleteLocalRef(elem);
385     if (TF_GetCode(status) != TF_OK) return;
386   }
387 }
388 }  // namespace
389 
Java_org_tensorflow_Tensor_allocateNonScalarBytes(JNIEnv * env,jclass clazz,jlongArray shape,jobjectArray value)390 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
391     JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
392   // TF_STRING tensors are encoded with a table of 8-byte offsets following by
393   // TF_StringEncode-encoded bytes.
394   const int num_dims = static_cast<int>(env->GetArrayLength(shape));
395   int64_t* dims = new int64_t[num_dims];
396   int64_t num_elements = 1;
397   {
398     jlong* jdims = env->GetLongArrayElements(shape, nullptr);
399     for (int i = 0; i < num_dims; ++i) {
400       dims[i] = static_cast<int64_t>(jdims[i]);
401       num_elements *= dims[i];
402     }
403     env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
404   }
405   checkForNullEntries(env, value, num_dims);
406   if (env->ExceptionCheck()) return 0;
407   TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
408                                    sizeof(TF_TString) * num_elements);
409   if (t == nullptr) {
410     delete[] dims;
411     throwException(env, kNullPointerException,
412                    "unable to allocate memory for the Tensor");
413     return 0;
414   }
415   TF_Status* status = TF_NewStatus();
416   StringTensorWriter writer(t, num_elements);
417   fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
418   delete[] dims;
419   jlong ret = 0;
420   if (!throwExceptionIfNotOK(env, status)) {
421     TF_DeleteTensor(t);
422   } else {
423     ret = reinterpret_cast<jlong>(t);
424   }
425   TF_DeleteStatus(status);
426   return ret;
427 }
428 
Java_org_tensorflow_Tensor_delete(JNIEnv * env,jclass clazz,jlong handle)429 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
430                                                          jclass clazz,
431                                                          jlong handle) {
432   if (handle == 0) return;
433   TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
434 }
435 
Java_org_tensorflow_Tensor_buffer(JNIEnv * env,jclass clazz,jlong handle)436 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
437                                                             jclass clazz,
438                                                             jlong handle) {
439   TF_Tensor* t = requireHandle(env, handle);
440   if (t == nullptr) return nullptr;
441   void* data = TF_TensorData(t);
442   const size_t sz = TF_TensorByteSize(t);
443 
444   return env->NewDirectByteBuffer(data, static_cast<jlong>(sz));
445 }
446 
Java_org_tensorflow_Tensor_dtype(JNIEnv * env,jclass clazz,jlong handle)447 JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
448                                                         jclass clazz,
449                                                         jlong handle) {
450   static_assert(sizeof(jint) >= sizeof(TF_DataType),
451                 "TF_DataType in C cannot be represented as an int in Java");
452   TF_Tensor* t = requireHandle(env, handle);
453   if (t == nullptr) return 0;
454   return static_cast<jint>(TF_TensorType(t));
455 }
456 
Java_org_tensorflow_Tensor_shape(JNIEnv * env,jclass clazz,jlong handle)457 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
458                                                               jclass clazz,
459                                                               jlong handle) {
460   TF_Tensor* t = requireHandle(env, handle);
461   if (t == nullptr) return nullptr;
462   static_assert(sizeof(jlong) == sizeof(int64_t),
463                 "Java long is not compatible with the TensorFlow C API");
464   const jsize num_dims = TF_NumDims(t);
465   jlongArray ret = env->NewLongArray(num_dims);
466   jlong* dims = env->GetLongArrayElements(ret, nullptr);
467   for (int i = 0; i < num_dims; ++i) {
468     dims[i] = static_cast<jlong>(TF_Dim(t, i));
469   }
470   env->ReleaseLongArrayElements(ret, dims, 0);
471   return ret;
472 }
473 
Java_org_tensorflow_Tensor_setValue(JNIEnv * env,jclass clazz,jlong handle,jobject value)474 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
475                                                            jclass clazz,
476                                                            jlong handle,
477                                                            jobject value) {
478   TF_Tensor* t = requireHandle(env, handle);
479   if (t == nullptr) return;
480   int num_dims = TF_NumDims(t);
481   TF_DataType dtype = TF_TensorType(t);
482   void* data = TF_TensorData(t);
483   const size_t sz = TF_TensorByteSize(t);
484   if (num_dims == 0) {
485     writeScalar(env, value, dtype, data, sz);
486   } else {
487     writeNDArray(env, static_cast<jarray>(value), dtype, num_dims,
488                  static_cast<char*>(data), sz);
489   }
490 }
491 
492 #define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix)                  \
493   JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix(    \
494       JNIEnv* env, jclass clazz, jlong handle) {                               \
495     jtype ret = 0;                                                             \
496     TF_Tensor* t = requireHandle(env, handle);                                 \
497     if (t == nullptr) return ret;                                              \
498     if (TF_NumDims(t) != 0) {                                                  \
499       throwException(env, kIllegalStateException, "Tensor is not a scalar");   \
500     } else if (TF_TensorType(t) != dtype) {                                    \
501       throwException(env, kIllegalStateException, "Tensor is not a %s scalar", \
502                      #method_suffix);                                          \
503     } else {                                                                   \
504       memcpy(&ret, TF_TensorData(t), elemByteSize(dtype));                     \
505     }                                                                          \
506     return ret;                                                                \
507   }
508 DEFINE_GET_SCALAR_METHOD(jfloat, TF_FLOAT, Float);
509 DEFINE_GET_SCALAR_METHOD(jdouble, TF_DOUBLE, Double);
510 DEFINE_GET_SCALAR_METHOD(jint, TF_INT32, Int);
511 DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
512 DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
513 #undef DEFINE_GET_SCALAR_METHOD
514 
Java_org_tensorflow_Tensor_scalarBytes(JNIEnv * env,jclass clazz,jlong handle)515 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
516     JNIEnv* env, jclass clazz, jlong handle) {
517   TF_Tensor* t = requireHandle(env, handle);
518   if (t == nullptr) return nullptr;
519   if (TF_NumDims(t) != 0) {
520     throwException(env, kIllegalStateException, "Tensor is not a scalar");
521     return nullptr;
522   }
523   if (TF_TensorType(t) != TF_STRING) {
524     throwException(env, kIllegalArgumentException,
525                    "Tensor is not a string/bytes scalar");
526     return nullptr;
527   }
528   const TF_TString* data = static_cast<const TF_TString*>(TF_TensorData(t));
529   jbyteArray ret = TF_StringDecodeTojbyteArray(env, &data[0]);
530   return ret;
531 }
532 
Java_org_tensorflow_Tensor_readNDArray(JNIEnv * env,jclass clazz,jlong handle,jobject value)533 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
534                                                               jclass clazz,
535                                                               jlong handle,
536                                                               jobject value) {
537   TF_Tensor* t = requireHandle(env, handle);
538   if (t == nullptr) return;
539   int num_dims = TF_NumDims(t);
540   TF_DataType dtype = TF_TensorType(t);
541   const void* data = TF_TensorData(t);
542   const size_t sz = TF_TensorByteSize(t);
543   if (num_dims == 0) {
544     throwException(env, kIllegalArgumentException,
545                    "copyTo() is not meant for scalar Tensors, use the scalar "
546                    "accessor (floatValue(), intValue() etc.) instead");
547     return;
548   }
549   if (dtype == TF_STRING) {
550     int64_t num_elements = 1;
551     for (int i = 0; i < num_dims; ++i) {
552       num_elements *= TF_Dim(t, i);
553     }
554     StringTensorReader reader(t, num_elements);
555     TF_Status* status = TF_NewStatus();
556     readNDStringArray(env, &reader, num_dims, static_cast<jobjectArray>(value),
557                       status);
558     throwExceptionIfNotOK(env, status);
559     TF_DeleteStatus(status);
560     return;
561   }
562   readNDArray(env, dtype, static_cast<const char*>(data), sz, num_dims,
563               static_cast<jarray>(value));
564 }
565