1 #include <string> 2 #include <vector> 3 #include "primitive_check.h" 4 #include "dtype/type_id.h" 5 #include "src/weight_decoder.h" 6 #include "src/common/log.h" 7 #include "src/common/utils.h" 8 namespace mindspore { 9 namespace lite { 10 CheckPrimitiveSupported(const schema::Primitive * primitive)11Status CheckPrimitiveSupported(const schema::Primitive *primitive) { 12 if (primitive != nullptr) { 13 auto prim = primitive; 14 auto type = prim->value_type(); 15 switch (type) { 16 case schema::PrimitiveType_Activation: 17 return mindspore::kSuccess; 18 case schema::PrimitiveType_AddFusion: 19 return mindspore::kSuccess; 20 case schema::PrimitiveType_ArgMaxFusion: 21 return mindspore::kSuccess; 22 case schema::PrimitiveType_AvgPoolFusion: 23 return mindspore::kSuccess; 24 case schema::PrimitiveType_BatchToSpaceND: 25 return mindspore::kSuccess; 26 case schema::PrimitiveType_BiasAdd: 27 return mindspore::kSuccess; 28 case schema::PrimitiveType_Cast: 29 return mindspore::kSuccess; 30 case schema::PrimitiveType_Concat: 31 return mindspore::kSuccess; 32 case schema::PrimitiveType_Conv2DFusion: 33 return mindspore::kSuccess; 34 case schema::PrimitiveType_Conv2dTransposeFusion: 35 return mindspore::kSuccess; 36 case schema::PrimitiveType_DivFusion: 37 return mindspore::kSuccess; 38 case schema::PrimitiveType_Eltwise: 39 return mindspore::kSuccess; 40 case schema::PrimitiveType_ExpandDims: 41 return mindspore::kSuccess; 42 case schema::PrimitiveType_Fill: 43 return mindspore::kSuccess; 44 case schema::PrimitiveType_FullConnection: 45 return mindspore::kSuccess; 46 case schema::PrimitiveType_FusedBatchNorm: 47 return mindspore::kSuccess; 48 case schema::PrimitiveType_Gather: 49 return mindspore::kSuccess; 50 case schema::PrimitiveType_LayerNormFusion: 51 return mindspore::kSuccess; 52 case schema::PrimitiveType_LessEqual: 53 return mindspore::kSuccess; 54 case schema::PrimitiveType_MatMulFusion: 55 return mindspore::kSuccess; 56 case schema::PrimitiveType_Maximum: 57 return mindspore::kSuccess; 58 case schema::PrimitiveType_MaxPoolFusion: 59 return mindspore::kSuccess; 60 case schema::PrimitiveType_MulFusion: 61 return mindspore::kSuccess; 62 case schema::PrimitiveType_OneHot: 63 return mindspore::kSuccess; 64 case schema::PrimitiveType_PadFusion: 65 return mindspore::kSuccess; 66 case schema::PrimitiveType_PowFusion: 67 return mindspore::kSuccess; 68 case schema::PrimitiveType_PReLUFusion: 69 return mindspore::kSuccess; 70 case schema::PrimitiveType_QuantDTypeCast: 71 return mindspore::kSuccess; 72 case schema::PrimitiveType_ReduceFusion: 73 return mindspore::kSuccess; 74 case schema::PrimitiveType_Reshape: 75 return mindspore::kSuccess; 76 case schema::PrimitiveType_Resize: 77 return mindspore::kSuccess; 78 case schema::PrimitiveType_Rsqrt: 79 return mindspore::kSuccess; 80 case schema::PrimitiveType_ScaleFusion: 81 return mindspore::kSuccess; 82 case schema::PrimitiveType_Shape: 83 return mindspore::kSuccess; 84 case schema::PrimitiveType_SliceFusion: 85 return mindspore::kSuccess; 86 case schema::PrimitiveType_Softmax: 87 return mindspore::kSuccess; 88 case schema::PrimitiveType_SpaceToBatchND: 89 return mindspore::kSuccess; 90 case schema::PrimitiveType_Split: 91 return mindspore::kSuccess; 92 case schema::PrimitiveType_Sqrt: 93 return mindspore::kSuccess; 94 case schema::PrimitiveType_SquaredDifference: 95 return mindspore::kSuccess; 96 case schema::PrimitiveType_Squeeze: 97 return mindspore::kSuccess; 98 case schema::PrimitiveType_Stack: 99 return mindspore::kSuccess; 100 case schema::PrimitiveType_StridedSlice: 101 return mindspore::kSuccess; 102 case schema::PrimitiveType_SubFusion: 103 return mindspore::kSuccess; 104 case schema::PrimitiveType_TileFusion: 105 return mindspore::kSuccess; 106 case schema::PrimitiveType_TopKFusion: 107 return mindspore::kSuccess; 108 case schema::PrimitiveType_Transpose: 109 return mindspore::kSuccess; 110 case schema::PrimitiveType_Unsqueeze: 111 return mindspore::kSuccess; 112 default: { 113 MS_LOG(WARNING) << "No primitive type :" << (int)(type); 114 return mindspore::kLiteSuccessExit; 115 } 116 } 117 return mindspore::kSuccess; 118 } else { 119 MS_LOG(ERROR) << "primitive is nullptr."; 120 return mindspore::kLiteError; 121 } 122 } 123 namespace { NeedBitUppackCheck(const schema::Tensor & src_tensor)124bool NeedBitUppackCheck(const schema::Tensor &src_tensor) { 125 if (src_tensor.enableHuffmanCode()) { 126 return true; 127 } 128 bool need_bit_unpack = src_tensor.quantParams() != nullptr && src_tensor.quantParams()->size() > 0 && 129 src_tensor.quantParams()->Get(0) != nullptr; 130 if (need_bit_unpack) { 131 auto num_bits = src_tensor.quantParams()->Get(0)->numBits(); 132 need_bit_unpack = ((num_bits >= kBitNum1 && num_bits < kBitNum8) || (num_bits > kBitNum8 && num_bits < kBitNum16)); 133 } 134 135 return need_bit_unpack; 136 } DecompressTensor(const schema::Tensor & src_tensor)137int DecompressTensor(const schema::Tensor &src_tensor) { 138 if (src_tensor.weightQunatCompressType() == schema::WeightQunatCompressType_FSE || 139 src_tensor.weightQunatCompressType() == schema::WeightQunatCompressType_INDEXING || 140 src_tensor.weightQunatCompressType() == schema::WeightQunatCompressType_SPARSE) { 141 return RET_NOT_SUPPORT; 142 } 143 if (!NeedBitUppackCheck(src_tensor)) { 144 return RET_NO_CHANGE; 145 } 146 MS_LOG(ERROR) << "DecompressTensor Error."; 147 return RET_ERROR; 148 } 149 } // namespace 150 CheckTensorSupported(const schema::Tensor * primitive)151Status CheckTensorSupported(const schema::Tensor *primitive) { 152 if (primitive == nullptr) { 153 MS_LOG(ERROR) << "primitive is nullptr, which type is Tensor."; 154 return mindspore::kLiteSuccessExit; 155 } 156 157 int32_t data_type = primitive->dataType(); 158 if (data_type <= kTypeUnknown || data_type >= kMonadTypeEnd) { 159 MS_LOG(ERROR) << "invalid data type. " << data_type; 160 return mindspore::kLiteSuccessExit; 161 } 162 163 if (primitive->dims() == nullptr) { 164 MS_LOG(DEBUG) << "Dims of tensor is nullptr"; 165 } 166 167 if (data_type == kObjectTypeTensorType) { 168 MS_LOG(ERROR) << "Not support TensorList."; 169 return mindspore::kLiteNotSupport; 170 } 171 172 if (primitive->data() == nullptr || primitive->data()->size() <= 0) { 173 MS_LOG(DEBUG) << "No valid data converted."; 174 return mindspore::kSuccess; 175 } else { 176 auto ret = DecompressTensor(*primitive); 177 if (ret == RET_NO_CHANGE) { 178 } else { 179 MS_LOG(ERROR) << "Not support Decompress Tensor."; 180 return mindspore::kLiteNotSupport; 181 } 182 } 183 return mindspore::kSuccess; 184 ; 185 } 186 } // namespace lite 187 } // namespace mindspore 188