• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <string>
2 #include <vector>
3 #include "primitive_check.h"
4 #include "dtype/type_id.h"
5 #include "src/litert/weight_decoder.h"
6 #include "src/common/log.h"
7 #include "src/common/utils.h"
8 namespace mindspore {
9 namespace lite {
10 namespace {
NeedBitUppackCheck(const schema::Tensor & src_tensor)11 bool NeedBitUppackCheck(const schema::Tensor &src_tensor) {
12   if (src_tensor.enableHuffmanCode()) {
13     return true;
14   }
15   bool need_bit_unpack = src_tensor.quantParams() != nullptr && src_tensor.quantParams()->size() > 0 &&
16                          src_tensor.quantParams()->Get(0) != nullptr;
17   if (need_bit_unpack) {
18     auto num_bits = src_tensor.quantParams()->Get(0)->numBits();
19     need_bit_unpack = ((num_bits >= kBitNum1 && num_bits < kBitNum8) || (num_bits > kBitNum8 && num_bits < kBitNum16));
20   }
21 
22   return need_bit_unpack;
23 }
DecompressTensor(const schema::Tensor & src_tensor)24 int DecompressTensor(const schema::Tensor &src_tensor) {
25   if (src_tensor.weightQuantCompressType() == schema::WeightQuantCompressType_FSE ||
26       src_tensor.weightQuantCompressType() == schema::WeightQuantCompressType_INDEXING ||
27       src_tensor.weightQuantCompressType() == schema::WeightQuantCompressType_SPARSE) {
28     return RET_NOT_SUPPORT;
29   }
30   if (!NeedBitUppackCheck(src_tensor)) {
31     return RET_NO_CHANGE;
32   }
33   MS_LOG(ERROR) << "DecompressTensor Error.";
34   return RET_ERROR;
35 }
36 }  // namespace
37 
CheckTensorSupported(const schema::Tensor * primitive)38 Status CheckTensorSupported(const schema::Tensor *primitive) {
39   if (primitive == nullptr) {
40     MS_LOG(ERROR) << "primitive is nullptr, which type is Tensor.";
41     return mindspore::kLiteSuccessExit;
42   }
43 
44   int32_t data_type = primitive->dataType();
45   if (data_type <= kTypeUnknown || data_type >= kMonadTypeEnd) {
46     MS_LOG(ERROR) << "invalid data type. " << data_type;
47     return mindspore::kLiteSuccessExit;
48   }
49 
50   if (primitive->dims() == nullptr) {
51     MS_LOG(DEBUG) << "Dims of tensor is nullptr";
52   }
53 
54   if (data_type == kObjectTypeTensorType) {
55     MS_LOG(ERROR) << "Not support TensorList.";
56     return mindspore::kLiteNotSupport;
57   }
58 
59   if (primitive->data() == nullptr || primitive->data()->size() <= 0) {
60     MS_LOG(DEBUG) << "No valid data converted.";
61     return mindspore::kSuccess;
62   } else {
63     auto ret = DecompressTensor(*primitive);
64     if (ret == RET_NO_CHANGE) {
65     } else {
66       MS_LOG(DEBUG) << "Not support Decompress Tensor.";
67       return mindspore::kLiteNotSupport;
68     }
69   }
70   return mindspore::kSuccess;
71   ;
72 }
73 }  // namespace lite
74 }  // namespace mindspore
75