• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <string>
2 #include <vector>
3 #include "primitive_check.h"
4 #include "dtype/type_id.h"
5 #include "src/weight_decoder.h"
6 #include "src/common/log.h"
7 #include "src/common/utils.h"
8 namespace mindspore {
9 namespace lite {
10 
CheckPrimitiveSupported(const schema::Primitive * primitive)11 Status CheckPrimitiveSupported(const schema::Primitive *primitive) {
12   if (primitive != nullptr) {
13     auto prim = primitive;
14     auto type = prim->value_type();
15     switch (type) {
16       case schema::PrimitiveType_Activation:
17         return mindspore::kSuccess;
18       case schema::PrimitiveType_AddFusion:
19         return mindspore::kSuccess;
20       case schema::PrimitiveType_ArgMaxFusion:
21         return mindspore::kSuccess;
22       case schema::PrimitiveType_AvgPoolFusion:
23         return mindspore::kSuccess;
24       case schema::PrimitiveType_BatchToSpaceND:
25         return mindspore::kSuccess;
26       case schema::PrimitiveType_BiasAdd:
27         return mindspore::kSuccess;
28       case schema::PrimitiveType_Cast:
29         return mindspore::kSuccess;
30       case schema::PrimitiveType_Concat:
31         return mindspore::kSuccess;
32       case schema::PrimitiveType_Conv2DFusion:
33         return mindspore::kSuccess;
34       case schema::PrimitiveType_Conv2dTransposeFusion:
35         return mindspore::kSuccess;
36       case schema::PrimitiveType_DivFusion:
37         return mindspore::kSuccess;
38       case schema::PrimitiveType_Eltwise:
39         return mindspore::kSuccess;
40       case schema::PrimitiveType_ExpandDims:
41         return mindspore::kSuccess;
42       case schema::PrimitiveType_Fill:
43         return mindspore::kSuccess;
44       case schema::PrimitiveType_FullConnection:
45         return mindspore::kSuccess;
46       case schema::PrimitiveType_FusedBatchNorm:
47         return mindspore::kSuccess;
48       case schema::PrimitiveType_Gather:
49         return mindspore::kSuccess;
50       case schema::PrimitiveType_LayerNormFusion:
51         return mindspore::kSuccess;
52       case schema::PrimitiveType_LessEqual:
53         return mindspore::kSuccess;
54       case schema::PrimitiveType_MatMulFusion:
55         return mindspore::kSuccess;
56       case schema::PrimitiveType_Maximum:
57         return mindspore::kSuccess;
58       case schema::PrimitiveType_MaxPoolFusion:
59         return mindspore::kSuccess;
60       case schema::PrimitiveType_MulFusion:
61         return mindspore::kSuccess;
62       case schema::PrimitiveType_OneHot:
63         return mindspore::kSuccess;
64       case schema::PrimitiveType_PadFusion:
65         return mindspore::kSuccess;
66       case schema::PrimitiveType_PowFusion:
67         return mindspore::kSuccess;
68       case schema::PrimitiveType_PReLUFusion:
69         return mindspore::kSuccess;
70       case schema::PrimitiveType_QuantDTypeCast:
71         return mindspore::kSuccess;
72       case schema::PrimitiveType_ReduceFusion:
73         return mindspore::kSuccess;
74       case schema::PrimitiveType_Reshape:
75         return mindspore::kSuccess;
76       case schema::PrimitiveType_Resize:
77         return mindspore::kSuccess;
78       case schema::PrimitiveType_Rsqrt:
79         return mindspore::kSuccess;
80       case schema::PrimitiveType_ScaleFusion:
81         return mindspore::kSuccess;
82       case schema::PrimitiveType_Shape:
83         return mindspore::kSuccess;
84       case schema::PrimitiveType_SliceFusion:
85         return mindspore::kSuccess;
86       case schema::PrimitiveType_Softmax:
87         return mindspore::kSuccess;
88       case schema::PrimitiveType_SpaceToBatchND:
89         return mindspore::kSuccess;
90       case schema::PrimitiveType_Split:
91         return mindspore::kSuccess;
92       case schema::PrimitiveType_Sqrt:
93         return mindspore::kSuccess;
94       case schema::PrimitiveType_SquaredDifference:
95         return mindspore::kSuccess;
96       case schema::PrimitiveType_Squeeze:
97         return mindspore::kSuccess;
98       case schema::PrimitiveType_Stack:
99         return mindspore::kSuccess;
100       case schema::PrimitiveType_StridedSlice:
101         return mindspore::kSuccess;
102       case schema::PrimitiveType_SubFusion:
103         return mindspore::kSuccess;
104       case schema::PrimitiveType_TileFusion:
105         return mindspore::kSuccess;
106       case schema::PrimitiveType_TopKFusion:
107         return mindspore::kSuccess;
108       case schema::PrimitiveType_Transpose:
109         return mindspore::kSuccess;
110       case schema::PrimitiveType_Unsqueeze:
111         return mindspore::kSuccess;
112       default: {
113         MS_LOG(WARNING) << "No primitive type :" << (int)(type);
114         return mindspore::kLiteSuccessExit;
115       }
116     }
117     return mindspore::kSuccess;
118   } else {
119     MS_LOG(ERROR) << "primitive is nullptr.";
120     return mindspore::kLiteError;
121   }
122 }
123 namespace {
NeedBitUppackCheck(const schema::Tensor & src_tensor)124 bool NeedBitUppackCheck(const schema::Tensor &src_tensor) {
125   if (src_tensor.enableHuffmanCode()) {
126     return true;
127   }
128   bool need_bit_unpack = src_tensor.quantParams() != nullptr && src_tensor.quantParams()->size() > 0 &&
129                          src_tensor.quantParams()->Get(0) != nullptr;
130   if (need_bit_unpack) {
131     auto num_bits = src_tensor.quantParams()->Get(0)->numBits();
132     need_bit_unpack = ((num_bits >= kBitNum1 && num_bits < kBitNum8) || (num_bits > kBitNum8 && num_bits < kBitNum16));
133   }
134 
135   return need_bit_unpack;
136 }
DecompressTensor(const schema::Tensor & src_tensor)137 int DecompressTensor(const schema::Tensor &src_tensor) {
138   if (src_tensor.weightQunatCompressType() == schema::WeightQunatCompressType_FSE ||
139       src_tensor.weightQunatCompressType() == schema::WeightQunatCompressType_INDEXING ||
140       src_tensor.weightQunatCompressType() == schema::WeightQunatCompressType_SPARSE) {
141     return RET_NOT_SUPPORT;
142   }
143   if (!NeedBitUppackCheck(src_tensor)) {
144     return RET_NO_CHANGE;
145   }
146   MS_LOG(ERROR) << "DecompressTensor Error.";
147   return RET_ERROR;
148 }
149 }  // namespace
150 
CheckTensorSupported(const schema::Tensor * primitive)151 Status CheckTensorSupported(const schema::Tensor *primitive) {
152   if (primitive == nullptr) {
153     MS_LOG(ERROR) << "primitive is nullptr, which type is Tensor.";
154     return mindspore::kLiteSuccessExit;
155   }
156 
157   int32_t data_type = primitive->dataType();
158   if (data_type <= kTypeUnknown || data_type >= kMonadTypeEnd) {
159     MS_LOG(ERROR) << "invalid data type. " << data_type;
160     return mindspore::kLiteSuccessExit;
161   }
162 
163   if (primitive->dims() == nullptr) {
164     MS_LOG(DEBUG) << "Dims of tensor is nullptr";
165   }
166 
167   if (data_type == kObjectTypeTensorType) {
168     MS_LOG(ERROR) << "Not support TensorList.";
169     return mindspore::kLiteNotSupport;
170   }
171 
172   if (primitive->data() == nullptr || primitive->data()->size() <= 0) {
173     MS_LOG(DEBUG) << "No valid data converted.";
174     return mindspore::kSuccess;
175   } else {
176     auto ret = DecompressTensor(*primitive);
177     if (ret == RET_NO_CHANGE) {
178     } else {
179       MS_LOG(ERROR) << "Not support Decompress Tensor.";
180       return mindspore::kLiteNotSupport;
181     }
182   }
183   return mindspore::kSuccess;
184   ;
185 }
186 }  // namespace lite
187 }  // namespace mindspore
188