• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.mindspore;
18 
import static com.mindspore.config.MindsporeLite.POINTER_DEFAULT_VALUE;

import com.mindspore.config.DataType;
import com.mindspore.config.MindsporeLite;

import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
30 
31 /**
32  * The MSTensor class defines a tensor in MindSpore.
33  *
34  * @since v1.0
35  */
36 public class MSTensor {
37     private static final Logger LOGGER = Logger.getLogger(MSTensor.class.toString());
38 
39     static {
MindsporeLite.init()40         MindsporeLite.init();
41     }
42 
43     private long tensorPtr;
44     private Object buffer;
45 
46     /**
47      * MSTensor construct function.
48      */
MSTensor()49     public MSTensor() {
50         this.tensorPtr = POINTER_DEFAULT_VALUE;
51         this.buffer = null;
52     }
53 
54     /**
55      * MSTensor construct function.
56      *
57      * @param tensorPtr tensor pointer.
58      */
MSTensor(long tensorPtr)59     public MSTensor(long tensorPtr) {
60         this.tensorPtr = tensorPtr;
61         this.buffer = null;
62     }
63 
64     /**
65      * MSTensor construct function.
66      *
67      * @param tensorName tensor name
68      * @param buffer     tensor buffer
69      */
createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer)70     public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) {
71         if (tensorName == null || tensorShape == null || buffer == null || dataType < DataType.kNumberTypeBool ||
72             dataType > DataType.kNumberTypeFloat64) {
73             LOGGER.severe("input params null.");
74             return null;
75         }
76         long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer);
77         return new MSTensor(tensorPtr);
78     }
79 
80     /**
81      * MSTensor construct function.
82      *
83      * @param tensorName tensor name
84      * @param obj        java Array or a Scalar. Support dtype: float, double, int, long, boolean.
85      */
createTensor(String tensorName, Object obj)86     public static MSTensor createTensor(String tensorName, Object obj) {
87         if (tensorName == null || obj == null) {
88             LOGGER.severe("input params null.");
89             return null;
90         }
91         int dType = ParseDataType(obj);
92         if (dType == 0) {
93             LOGGER.severe("input param dtype invalid.");
94             return null;
95         }
96         int[] shape = ParseShape(obj);
97         if (shape == null) {
98             LOGGER.severe("input param shape null.");
99             return null;
100         }
101         long tensorPtr = createTensorByObject(tensorName, dType, shape, obj);
102         return new MSTensor(tensorPtr);
103     }
104 
105     /**
106      * Get the shape of the MindSpore MSTensor.
107      *
108      * @return A array of int as the shape of the MindSpore MSTensor.
109      */
getShape()110     public int[] getShape() {
111         return this.getShape(this.tensorPtr);
112     }
113 
114     /**
115      * DataType is defined in com.mindspore.DataType.
116      *
117      * @return The MindSpore data type of the MindSpore MSTensor class.
118      */
getDataType()119     public int getDataType() {
120         return this.getDataType(this.tensorPtr);
121     }
122 
123     /**
124      * Get output data of MSTensor, data type is the same as the type data is set.
125      *
126      * @return The byte array containing all MSTensor output data.
127      */
getData()128     public Object getData() {
129         Object ret = null;
130         if (this.buffer != null) {
131             return this.buffer;
132         } else {
133             int dataType = this.getDataType();
134             switch (dataType) {
135                 case DataType.kNumberTypeFloat32:
136                     ret = this.getFloatData(this.tensorPtr);
137                     break;
138                 case DataType.kNumberTypeFloat16:
139                     ret = this.getFloat16Data(this.tensorPtr);
140                     break;
141                 case DataType.kNumberTypeInt32:
142                     ret = this.getIntData(this.tensorPtr);
143                     break;
144                 case DataType.kNumberTypeInt64:
145                     ret = this.getLongData(this.tensorPtr);
146                     break;
147                 default:
148                     LOGGER.warning("Do not support data type: " + dataType + ", would return byte[] data");
149                     ret = this.getByteData(this.tensorPtr);
150             }
151         }
152         return ret;
153     }
154 
155     /**
156      * Get output data of MSTensor, the data type is byte.
157      *
158      * @return The byte array containing all MSTensor output data.
159      */
getByteData()160     public byte[] getByteData() {
161         if (this.buffer == null) {
162             return this.getByteData(this.tensorPtr);
163         }
164         if (this.buffer instanceof byte[]) {
165             return (byte[]) this.buffer;
166         }
167         return new byte[0];
168     }
169 
170     /**
171      * Get output data of MSTensor, the data type is float.
172      *
173      * @return The float array containing all MSTensor output data.
174      */
getFloatData()175     public float[] getFloatData() {
176         if (this.buffer == null) {
177             if (this.getDataType() == DataType.kNumberTypeFloat16) {
178                 return this.getFloat16Data(this.tensorPtr);
179             }
180             return this.getFloatData(this.tensorPtr);
181         }
182         if (this.buffer instanceof float[]) {
183             return (float[]) this.buffer;
184         }
185         int dataType = this.getDataType();
186         float[] floatArray = new float[0];
187         if (this.buffer instanceof byte[]
188             && (dataType == DataType.kNumberTypeFloat16 || dataType == DataType.kNumberTypeFloat32)) {
189             ByteBuffer byteBuffer = ByteBuffer.wrap((byte[]) this.buffer);
190             FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
191             floatArray = new float[floatBuffer.remaining()];
192             floatBuffer.get(floatArray);
193         }
194         return floatArray;
195     }
196 
197     /**
198      * Get output data of MSTensor, the data type is int.
199      *
200      * @return The int array containing all MSTensor output data.
201      */
getIntData()202     public int[] getIntData() {
203         if (this.buffer == null) {
204             return this.getIntData(this.tensorPtr);
205         }
206         if (this.buffer instanceof int[]) {
207             return (int[]) this.buffer;
208         }
209         int dataType = this.getDataType();
210         int[] intArray = new int[0];
211         if (this.buffer instanceof byte[]
212             && (dataType == DataType.kNumberTypeInt32)) {
213             byte[] byteArray = (byte[]) this.buffer;
214             intArray = new int[byteArray.length];
215             for (int i = 0; i < byteArray.length; i++) {
216                 intArray[i] = byteArray[i] & 0xff;
217             }
218         }
219         return intArray;
220     }
221 
222     /**
223      * Get output data of MSTensor, the data type is long.
224      *
225      * @return The long array containing all MSTensor output data.
226      */
getLongData()227     public long[] getLongData() {
228         if (this.buffer == null) {
229             return this.getLongData(this.tensorPtr);
230         }
231         if (this.buffer instanceof long[]) {
232             return (long[]) this.buffer;
233         }
234         int dataType = this.getDataType();
235         long[] longArray = new long[0];
236         if (this.buffer instanceof byte[]
237             && (dataType == DataType.kNumberTypeFloat16 || dataType == DataType.kNumberTypeFloat32)) {
238             ByteBuffer byteBuffer = ByteBuffer.wrap((byte[]) this.buffer);
239             LongBuffer longBuffer = byteBuffer.asLongBuffer();
240             longArray = new long[longBuffer.remaining()];
241             longBuffer.get(longArray);
242         }
243         return longArray;
244     }
245 
246     /**
247      * Set the shape of MSTensor.
248      *
249      * @param tensorShape of int[] type.
250      * @return whether set shape success.
251      */
setShape(int[] tensorShape)252     public boolean setShape(int[] tensorShape) {
253         if (tensorShape == null) {
254             LOGGER.severe("input param null.");
255             return false;
256         }
257         return this.setShape(this.tensorPtr, tensorShape);
258     }
259 
260     /**
261      * Set the input data of MSTensor.
262      *
263      * @param data Input data of ByteBuffer type.
264      * @return whether set data success.
265      */
setData(ByteBuffer data)266     public boolean setData(ByteBuffer data) {
267         if (data == null) {
268             LOGGER.severe("input param null.");
269             return false;
270         }
271         return this.setByteBufferData(this.tensorPtr, data);
272     }
273 
274     /**
275      * Set the input data of MSTensor.
276      *
277      * @param data Input data of byte[] type.
278      * @return whether set data success.
279      */
setData(byte[] data)280     public boolean setData(byte[] data) {
281         if (data == null) {
282             LOGGER.severe("input param null.");
283             return false;
284         }
285         if (data.length != this.size()) {
286             return false;
287         }
288         this.buffer = data;
289         return true;
290     }
291 
292     /**
293      * Set the input data of MSTensor.
294      *
295      * @param data Input data of float[] type.
296      * @return whether set data success.
297      */
setData(float[] data)298     public boolean setData(float[] data) {
299         if (data == null) {
300             LOGGER.severe("input param null.");
301             return false;
302         }
303         if (this.getDataType() != DataType.kNumberTypeFloat32
304             && this.getDataType() != DataType.kNumberTypeFloat16) {
305             LOGGER.severe("Data type is not consistent");
306             return false;
307         }
308         if (data.length != this.elementsNum()) {
309             return false;
310         }
311         this.buffer = data;
312         return true;
313     }
314 
315     /**
316      * Set the input data of MSTensor.
317      *
318      * @param data Input data of int[] type.
319      * @return whether set data success.
320      */
setData(int[] data)321     public boolean setData(int[] data) {
322         if (data == null) {
323             LOGGER.severe("input param null.");
324             return false;
325         }
326         if (this.getDataType() != DataType.kNumberTypeInt32) {
327             LOGGER.severe("Data type is not consistent");
328             return false;
329         }
330         if (data.length != this.elementsNum()) {
331             return false;
332         }
333         this.buffer = data;
334         return true;
335     }
336 
337     /**
338      * Set the input data of MSTensor.
339      *
340      * @param data Input data of long[] type.
341      * @return whether set data success.
342      */
setData(long[] data)343     public boolean setData(long[] data) {
344         if (data == null) {
345             LOGGER.severe("input param null.");
346             return false;
347         }
348         if (this.getDataType() != DataType.kNumberTypeInt64) {
349             LOGGER.severe("Data type is not consistent");
350             return false;
351         }
352         if (data.length != this.elementsNum()) {
353             return false;
354         }
355         this.buffer = data;
356         return true;
357     }
358 
359     /**
360      * Get the size of the data in MSTensor in bytes.
361      *
362      * @return The size of the data in MSTensor in bytes.
363      */
size()364     public long size() {
365         return this.size(this.tensorPtr);
366     }
367 
368     /**
369      * Get the number of elements in MSTensor.
370      *
371      * @return The number of elements in MSTensor.
372      */
elementsNum()373     public int elementsNum() {
374         return this.elementsNum(this.tensorPtr);
375     }
376 
377     /**
378      * Free all temporary memory in MindSpore MSTensor.
379      */
free()380     public void free() {
381         this.free(this.tensorPtr);
382         this.tensorPtr = POINTER_DEFAULT_VALUE;
383         this.buffer = null;
384     }
385 
386     /**
387      * @return Get tensor name
388      */
tensorName()389     public String tensorName() {
390         return this.tensorName(this.tensorPtr);
391     }
392 
393     /**
394      * @return MSTensor pointer
395      */
getMSTensorPtr()396     public long getMSTensorPtr() {
397         return tensorPtr;
398     }
399 
ParseDataType(Object obj)400     private static int ParseDataType(Object obj) {
401         HashMap<Class<?>, Integer> classToDType = new HashMap<Class<?>, Integer>() {{
402             put(float.class, DataType.kNumberTypeFloat32);
403             put(Float.class, DataType.kNumberTypeFloat32);
404             put(double.class, DataType.kNumberTypeFloat64);
405             put(Double.class, DataType.kNumberTypeFloat64);
406             put(int.class, DataType.kNumberTypeInt32);
407             put(Integer.class, DataType.kNumberTypeInt32);
408             put(long.class, DataType.kNumberTypeInt64);
409             put(Long.class, DataType.kNumberTypeInt64);
410             put(boolean.class, DataType.kNumberTypeBool);
411             put(Boolean.class, DataType.kNumberTypeBool);
412         }};
413         Class<?> c = obj.getClass();
414         while (c.isArray()) {
415             c = c.getComponentType();
416         }
417         Integer dType = classToDType.get(c);
418         return dType == null ? 0 : dType;
419     }
420 
ParseShape(Object obj)421     private static int[] ParseShape(Object obj) {
422         int i = 0;
423         Class<?> c = obj.getClass();
424         while (c.isArray()) {
425             c = c.getComponentType();
426             ++i;
427         }
428         int[] shape = new int[i];
429         i = 0;
430         c = obj.getClass();
431         while (c.isArray()) {
432             shape[i] = Array.getLength(obj);
433             if (shape[i] <= 0) {
434                 return null;
435             }
436             obj = Array.get(obj, 0);
437             c = c.getComponentType();
438             ++i;
439         }
440         return shape;
441     }
442 
createTensorByNative(String tensorName, int dataType, int[] tesorShape, ByteBuffer buffer)443     private static native long createTensorByNative(String tensorName, int dataType, int[] tesorShape,
444                                                     ByteBuffer buffer);
445 
createTensorByObject(String tensorName, int dataType, int[] tesorShape, Object obj)446     private static native long createTensorByObject(String tensorName, int dataType, int[] tesorShape,
447                                                     Object obj);
448 
getShape(long tensorPtr)449     private native int[] getShape(long tensorPtr);
450 
getDataType(long tensorPtr)451     private native int getDataType(long tensorPtr);
452 
getByteData(long tensorPtr)453     private native byte[] getByteData(long tensorPtr);
454 
getLongData(long tensorPtr)455     private native long[] getLongData(long tensorPtr);
456 
getIntData(long tensorPtr)457     private native int[] getIntData(long tensorPtr);
458 
getFloatData(long tensorPtr)459     private native float[] getFloatData(long tensorPtr);
460 
getFloat16Data(long tensorPtr)461     private native float[] getFloat16Data(long tensorPtr);
462 
setByteData(long tensorPtr, byte[] data, long dataLen)463     private native boolean setByteData(long tensorPtr, byte[] data, long dataLen);
464 
setFloatData(long tensorPtr, float[] data, long dataLen)465     private native boolean setFloatData(long tensorPtr, float[] data, long dataLen);
466 
setIntData(long tensorPtr, int[] data, long dataLen)467     private native boolean setIntData(long tensorPtr, int[] data, long dataLen);
468 
setLongData(long tensorPtr, long[] data, long dataLen)469     private native boolean setLongData(long tensorPtr, long[] data, long dataLen);
470 
setShape(long tensorPtr, int[] tensorShape)471     private native boolean setShape(long tensorPtr, int[] tensorShape);
472 
setByteBufferData(long tensorPtr, ByteBuffer buffer)473     private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
474 
size(long tensorPtr)475     private native long size(long tensorPtr);
476 
elementsNum(long tensorPtr)477     private native int elementsNum(long tensorPtr);
478 
free(long tensorPtr)479     private native void free(long tensorPtr);
480 
tensorName(long tensorPtr)481     private native String tensorName(long tensorPtr);
482 }