1 /* 2 * Copyright 2021-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.mindspore; 18 19 import static com.mindspore.config.MindsporeLite.POINTER_DEFAULT_VALUE; 20 21 import com.mindspore.config.MindsporeLite; 22 import com.mindspore.config.DataType; 23 24 import java.nio.ByteBuffer; 25 import java.nio.LongBuffer; 26 import java.nio.FloatBuffer; 27 import java.lang.reflect.Array; 28 import java.util.HashMap; 29 import java.util.logging.Logger; 30 31 /** 32 * The MSTensor class defines a tensor in MindSpore. 33 * 34 * @since v1.0 35 */ 36 public class MSTensor { 37 private static final Logger LOGGER = Logger.getLogger(MSTensor.class.toString()); 38 39 static { MindsporeLite.init()40 MindsporeLite.init(); 41 } 42 43 private long tensorPtr; 44 private Object buffer; 45 46 /** 47 * MSTensor construct function. 48 */ MSTensor()49 public MSTensor() { 50 this.tensorPtr = POINTER_DEFAULT_VALUE; 51 this.buffer = null; 52 } 53 54 /** 55 * MSTensor construct function. 56 * 57 * @param tensorPtr tensor pointer. 58 */ MSTensor(long tensorPtr)59 public MSTensor(long tensorPtr) { 60 this.tensorPtr = tensorPtr; 61 this.buffer = null; 62 } 63 64 /** 65 * MSTensor construct function. 66 * 67 * @param tensorName tensor name 68 * @param buffer tensor buffer 69 */ createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer)70 public static MSTensor createTensor(String tensorName, int dataType, int[] tensorShape, ByteBuffer buffer) { 71 if (tensorName == null || tensorShape == null || buffer == null || dataType < DataType.kNumberTypeBool || 72 dataType > DataType.kNumberTypeFloat64) { 73 LOGGER.severe("input params null."); 74 return null; 75 } 76 long tensorPtr = createTensorByNative(tensorName, dataType, tensorShape, buffer); 77 return new MSTensor(tensorPtr); 78 } 79 80 /** 81 * MSTensor construct function. 82 * 83 * @param tensorName tensor name 84 * @param obj java Array or a Scalar. Support dtype: float, double, int, long, boolean. 85 */ createTensor(String tensorName, Object obj)86 public static MSTensor createTensor(String tensorName, Object obj) { 87 if (tensorName == null || obj == null) { 88 LOGGER.severe("input params null."); 89 return null; 90 } 91 int dType = ParseDataType(obj); 92 if (dType == 0) { 93 LOGGER.severe("input param dtype invalid."); 94 return null; 95 } 96 int[] shape = ParseShape(obj); 97 if (shape == null) { 98 LOGGER.severe("input param shape null."); 99 return null; 100 } 101 long tensorPtr = createTensorByObject(tensorName, dType, shape, obj); 102 return new MSTensor(tensorPtr); 103 } 104 105 /** 106 * Get the shape of the MindSpore MSTensor. 107 * 108 * @return A array of int as the shape of the MindSpore MSTensor. 109 */ getShape()110 public int[] getShape() { 111 return this.getShape(this.tensorPtr); 112 } 113 114 /** 115 * DataType is defined in com.mindspore.DataType. 116 * 117 * @return The MindSpore data type of the MindSpore MSTensor class. 118 */ getDataType()119 public int getDataType() { 120 return this.getDataType(this.tensorPtr); 121 } 122 123 /** 124 * Get output data of MSTensor, data type is the same as the type data is set. 125 * 126 * @return The byte array containing all MSTensor output data. 127 */ getData()128 public Object getData() { 129 Object ret = null; 130 if (this.buffer != null) { 131 return this.buffer; 132 } else { 133 int dataType = this.getDataType(); 134 switch (dataType) { 135 case DataType.kNumberTypeFloat32: 136 ret = this.getFloatData(this.tensorPtr); 137 break; 138 case DataType.kNumberTypeFloat16: 139 ret = this.getFloat16Data(this.tensorPtr); 140 break; 141 case DataType.kNumberTypeInt32: 142 ret = this.getIntData(this.tensorPtr); 143 break; 144 case DataType.kNumberTypeInt64: 145 ret = this.getLongData(this.tensorPtr); 146 break; 147 default: 148 LOGGER.warning("Do not support data type: " + dataType + ", would return byte[] data"); 149 ret = this.getByteData(this.tensorPtr); 150 } 151 } 152 return ret; 153 } 154 155 /** 156 * Get output data of MSTensor, the data type is byte. 157 * 158 * @return The byte array containing all MSTensor output data. 159 */ getByteData()160 public byte[] getByteData() { 161 if (this.buffer == null) { 162 return this.getByteData(this.tensorPtr); 163 } 164 if (this.buffer instanceof byte[]) { 165 return (byte[]) this.buffer; 166 } 167 return new byte[0]; 168 } 169 170 /** 171 * Get output data of MSTensor, the data type is float. 172 * 173 * @return The float array containing all MSTensor output data. 174 */ getFloatData()175 public float[] getFloatData() { 176 if (this.buffer == null) { 177 if (this.getDataType() == DataType.kNumberTypeFloat16) { 178 return this.getFloat16Data(this.tensorPtr); 179 } 180 return this.getFloatData(this.tensorPtr); 181 } 182 if (this.buffer instanceof float[]) { 183 return (float[]) this.buffer; 184 } 185 int dataType = this.getDataType(); 186 float[] floatArray = new float[0]; 187 if (this.buffer instanceof byte[] 188 && (dataType == DataType.kNumberTypeFloat16 || dataType == DataType.kNumberTypeFloat32)) { 189 ByteBuffer byteBuffer = ByteBuffer.wrap((byte[]) this.buffer); 190 FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); 191 floatArray = new float[floatBuffer.remaining()]; 192 floatBuffer.get(floatArray); 193 } 194 return floatArray; 195 } 196 197 /** 198 * Get output data of MSTensor, the data type is int. 199 * 200 * @return The int array containing all MSTensor output data. 201 */ getIntData()202 public int[] getIntData() { 203 if (this.buffer == null) { 204 return this.getIntData(this.tensorPtr); 205 } 206 if (this.buffer instanceof int[]) { 207 return (int[]) this.buffer; 208 } 209 int dataType = this.getDataType(); 210 int[] intArray = new int[0]; 211 if (this.buffer instanceof byte[] 212 && (dataType == DataType.kNumberTypeInt32)) { 213 byte[] byteArray = (byte[]) this.buffer; 214 intArray = new int[byteArray.length]; 215 for (int i = 0; i < byteArray.length; i++) { 216 intArray[i] = byteArray[i] & 0xff; 217 } 218 } 219 return intArray; 220 } 221 222 /** 223 * Get output data of MSTensor, the data type is long. 224 * 225 * @return The long array containing all MSTensor output data. 226 */ getLongData()227 public long[] getLongData() { 228 if (this.buffer == null) { 229 return this.getLongData(this.tensorPtr); 230 } 231 if (this.buffer instanceof long[]) { 232 return (long[]) this.buffer; 233 } 234 int dataType = this.getDataType(); 235 long[] longArray = new long[0]; 236 if (this.buffer instanceof byte[] 237 && (dataType == DataType.kNumberTypeFloat16 || dataType == DataType.kNumberTypeFloat32)) { 238 ByteBuffer byteBuffer = ByteBuffer.wrap((byte[]) this.buffer); 239 LongBuffer longBuffer = byteBuffer.asLongBuffer(); 240 longArray = new long[longBuffer.remaining()]; 241 longBuffer.get(longArray); 242 } 243 return longArray; 244 } 245 246 /** 247 * Set the shape of MSTensor. 248 * 249 * @param tensorShape of int[] type. 250 * @return whether set shape success. 251 */ setShape(int[] tensorShape)252 public boolean setShape(int[] tensorShape) { 253 if (tensorShape == null) { 254 LOGGER.severe("input param null."); 255 return false; 256 } 257 return this.setShape(this.tensorPtr, tensorShape); 258 } 259 260 /** 261 * Set the input data of MSTensor. 262 * 263 * @param data Input data of ByteBuffer type. 264 * @return whether set data success. 265 */ setData(ByteBuffer data)266 public boolean setData(ByteBuffer data) { 267 if (data == null) { 268 LOGGER.severe("input param null."); 269 return false; 270 } 271 return this.setByteBufferData(this.tensorPtr, data); 272 } 273 274 /** 275 * Set the input data of MSTensor. 276 * 277 * @param data Input data of byte[] type. 278 * @return whether set data success. 279 */ setData(byte[] data)280 public boolean setData(byte[] data) { 281 if (data == null) { 282 LOGGER.severe("input param null."); 283 return false; 284 } 285 if (data.length != this.size()) { 286 return false; 287 } 288 this.buffer = data; 289 return true; 290 } 291 292 /** 293 * Set the input data of MSTensor. 294 * 295 * @param data Input data of float[] type. 296 * @return whether set data success. 297 */ setData(float[] data)298 public boolean setData(float[] data) { 299 if (data == null) { 300 LOGGER.severe("input param null."); 301 return false; 302 } 303 if (this.getDataType() != DataType.kNumberTypeFloat32 304 && this.getDataType() != DataType.kNumberTypeFloat16) { 305 LOGGER.severe("Data type is not consistent"); 306 return false; 307 } 308 if (data.length != this.elementsNum()) { 309 return false; 310 } 311 this.buffer = data; 312 return true; 313 } 314 315 /** 316 * Set the input data of MSTensor. 317 * 318 * @param data Input data of int[] type. 319 * @return whether set data success. 320 */ setData(int[] data)321 public boolean setData(int[] data) { 322 if (data == null) { 323 LOGGER.severe("input param null."); 324 return false; 325 } 326 if (this.getDataType() != DataType.kNumberTypeInt32) { 327 LOGGER.severe("Data type is not consistent"); 328 return false; 329 } 330 if (data.length != this.elementsNum()) { 331 return false; 332 } 333 this.buffer = data; 334 return true; 335 } 336 337 /** 338 * Set the input data of MSTensor. 339 * 340 * @param data Input data of long[] type. 341 * @return whether set data success. 342 */ setData(long[] data)343 public boolean setData(long[] data) { 344 if (data == null) { 345 LOGGER.severe("input param null."); 346 return false; 347 } 348 if (this.getDataType() != DataType.kNumberTypeInt64) { 349 LOGGER.severe("Data type is not consistent"); 350 return false; 351 } 352 if (data.length != this.elementsNum()) { 353 return false; 354 } 355 this.buffer = data; 356 return true; 357 } 358 359 /** 360 * Get the size of the data in MSTensor in bytes. 361 * 362 * @return The size of the data in MSTensor in bytes. 363 */ size()364 public long size() { 365 return this.size(this.tensorPtr); 366 } 367 368 /** 369 * Get the number of elements in MSTensor. 370 * 371 * @return The number of elements in MSTensor. 372 */ elementsNum()373 public int elementsNum() { 374 return this.elementsNum(this.tensorPtr); 375 } 376 377 /** 378 * Free all temporary memory in MindSpore MSTensor. 379 */ free()380 public void free() { 381 this.free(this.tensorPtr); 382 this.tensorPtr = POINTER_DEFAULT_VALUE; 383 this.buffer = null; 384 } 385 386 /** 387 * @return Get tensor name 388 */ tensorName()389 public String tensorName() { 390 return this.tensorName(this.tensorPtr); 391 } 392 393 /** 394 * @return MSTensor pointer 395 */ getMSTensorPtr()396 public long getMSTensorPtr() { 397 return tensorPtr; 398 } 399 ParseDataType(Object obj)400 private static int ParseDataType(Object obj) { 401 HashMap<Class<?>, Integer> classToDType = new HashMap<Class<?>, Integer>() {{ 402 put(float.class, DataType.kNumberTypeFloat32); 403 put(Float.class, DataType.kNumberTypeFloat32); 404 put(double.class, DataType.kNumberTypeFloat64); 405 put(Double.class, DataType.kNumberTypeFloat64); 406 put(int.class, DataType.kNumberTypeInt32); 407 put(Integer.class, DataType.kNumberTypeInt32); 408 put(long.class, DataType.kNumberTypeInt64); 409 put(Long.class, DataType.kNumberTypeInt64); 410 put(boolean.class, DataType.kNumberTypeBool); 411 put(Boolean.class, DataType.kNumberTypeBool); 412 }}; 413 Class<?> c = obj.getClass(); 414 while (c.isArray()) { 415 c = c.getComponentType(); 416 } 417 Integer dType = classToDType.get(c); 418 return dType == null ? 0 : dType; 419 } 420 ParseShape(Object obj)421 private static int[] ParseShape(Object obj) { 422 int i = 0; 423 Class<?> c = obj.getClass(); 424 while (c.isArray()) { 425 c = c.getComponentType(); 426 ++i; 427 } 428 int[] shape = new int[i]; 429 i = 0; 430 c = obj.getClass(); 431 while (c.isArray()) { 432 shape[i] = Array.getLength(obj); 433 if (shape[i] <= 0) { 434 return null; 435 } 436 obj = Array.get(obj, 0); 437 c = c.getComponentType(); 438 ++i; 439 } 440 return shape; 441 } 442 createTensorByNative(String tensorName, int dataType, int[] tesorShape, ByteBuffer buffer)443 private static native long createTensorByNative(String tensorName, int dataType, int[] tesorShape, 444 ByteBuffer buffer); 445 createTensorByObject(String tensorName, int dataType, int[] tesorShape, Object obj)446 private static native long createTensorByObject(String tensorName, int dataType, int[] tesorShape, 447 Object obj); 448 getShape(long tensorPtr)449 private native int[] getShape(long tensorPtr); 450 getDataType(long tensorPtr)451 private native int getDataType(long tensorPtr); 452 getByteData(long tensorPtr)453 private native byte[] getByteData(long tensorPtr); 454 getLongData(long tensorPtr)455 private native long[] getLongData(long tensorPtr); 456 getIntData(long tensorPtr)457 private native int[] getIntData(long tensorPtr); 458 getFloatData(long tensorPtr)459 private native float[] getFloatData(long tensorPtr); 460 getFloat16Data(long tensorPtr)461 private native float[] getFloat16Data(long tensorPtr); 462 setByteData(long tensorPtr, byte[] data, long dataLen)463 private native boolean setByteData(long tensorPtr, byte[] data, long dataLen); 464 setFloatData(long tensorPtr, float[] data, long dataLen)465 private native boolean setFloatData(long tensorPtr, float[] data, long dataLen); 466 setIntData(long tensorPtr, int[] data, long dataLen)467 private native boolean setIntData(long tensorPtr, int[] data, long dataLen); 468 setLongData(long tensorPtr, long[] data, long dataLen)469 private native boolean setLongData(long tensorPtr, long[] data, long dataLen); 470 setShape(long tensorPtr, int[] tensorShape)471 private native boolean setShape(long tensorPtr, int[] tensorShape); 472 setByteBufferData(long tensorPtr, ByteBuffer buffer)473 private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer); 474 size(long tensorPtr)475 private native long size(long tensorPtr); 476 elementsNum(long tensorPtr)477 private native int elementsNum(long tensorPtr); 478 free(long tensorPtr)479 private native void free(long tensorPtr); 480 tensorName(long tensorPtr)481 private native String tensorName(long tensorPtr); 482 }