1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "include/version.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/full_connection.h"
#include "ops/mat_mul.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
#include "abstract/abstract_value.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
41
42 using std::string;
43 using std::vector;
44
45 namespace mindspore::lite::quant {
46 const std::vector<std::string> QuantStrategy::conv_types_ = {ops::kNameConv2DFusion, ops::kNameConv2dTransposeFusion};
47 const std::vector<std::string> QuantStrategy::mul_types_ = {ops::kNameMatMul, ops::kNameFullConnection};
48 constexpr int kDim2 = 2;
49 constexpr int kDim4 = 4;
50
51 const int kLstmInputWeightIndex = 1;
52 const int kLstmStateWeightIndex = 2;
53 const int kLstmWeightShapeSize = 3;
54 const int kSingleDirBiasTensorSize = 4;
55 const int kLstmBiasShapeSize = 2;
56 const int kLstmBiasIndex = 3;
57
QuantStrategy(size_t weight_size,size_t conv_weight_quant_channel_threshold)58 QuantStrategy::QuantStrategy(size_t weight_size, size_t conv_weight_quant_channel_threshold)
59 : m_weight_size_(weight_size), m_conv_weight_quant_channel_threshold_(conv_weight_quant_channel_threshold) {}
60
CanConvOpQuantized(const CNodePtr & node) const61 bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
62 MS_CHECK_TRUE_RET(node != nullptr, false);
63 auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(node->input(0));
64 if (primitive_c == nullptr) {
65 MS_LOG(ERROR) << "primitive_c is nullptr";
66 return false;
67 }
68 if (!IsContain(conv_types_, primitive_c->name())) {
69 return false;
70 }
71 if (node->size() < 3) {
72 return false;
73 }
74 auto inputNode = node->input(2);
75 if (!inputNode->isa<Parameter>()) {
76 return false;
77 }
78 auto paramNode = inputNode->cast<ParameterPtr>();
79 MS_ASSERT(paramNode != nullptr);
80 auto abstract_base = paramNode->abstract();
81 if (abstract_base == nullptr) {
82 return false;
83 }
84 if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
85 MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
86 return false;
87 }
88 auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
89 size_t shapeSize = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
90 if (shapeSize < m_weight_size_) {
91 MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
92 return false;
93 }
94 if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
95 MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
96 return false;
97 }
98 return true;
99 }
100
CanOpFullQuantized(const AnfNodePtr & node)101 bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
102 MS_CHECK_TRUE_RET(node != nullptr, false);
103 if (!node->isa<mindspore::CNode>()) {
104 return false;
105 }
106 const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
107 MS_ASSERT(cnode != nullptr);
108 auto type = NodePrimitiveType(cnode);
109 static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,
110 prim::kPrimAvgPoolFusion, prim::kPrimConcat,
111 prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
112 prim::kPrimCrop, prim::kPrimFullConnection,
113 prim::kPrimGather, prim::kPrimLayerNormFusion,
114 prim::kPrimMatMul, prim::kPrimMaxPoolFusion,
115 prim::kPrimMulFusion, prim::kPrimReshape,
116 prim::kPrimSplit, prim::kPrimTranspose,
117 prim::kPrimReduceFusion, prim::kPrimDivFusion,
118 prim::kPrimSqrt, prim::kPrimPowFusion,
119 prim::kPrimSubFusion, prim::kPrimUnsqueeze,
120 prim::kPrimLayerNormFusion};
121 // The return node does not need to be quantified.
122 if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn) || opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
123 return false;
124 }
125 // These operators do not need to check the data type.
126 if (opt::CheckPrimitiveType(cnode, prim::kPrimShape) || opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
127 return true;
128 }
129 auto is_support_node = CheckNodeInSet(cnode, support_int8_ops);
130 if (!is_support_node && type != "Eltwise") {
131 MS_LOG(WARNING) << "node:" << cnode->fullname_with_scope() << " type:" << type << " is not support quantization.";
132 return false;
133 }
134 TypeId type_id;
135 auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
136 if (ret != RET_OK) {
137 MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
138 return ret;
139 }
140
141 bool is_data_type_fp32 = type_id == kNumberTypeFloat32;
142 if (!is_data_type_fp32) {
143 MS_LOG(INFO) << cnode->fullname_with_scope() << " type_id is " << type_id << " , and is not float32.";
144 }
145 return is_data_type_fp32;
146 }
147
CanMulOpQuantized(const CNodePtr & node) const148 bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
149 MS_CHECK_TRUE_RET(node != nullptr, false);
150 auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(node->input(0));
151 if (primitive_c == nullptr) {
152 MS_LOG(ERROR) << "primitive_c is nullptr";
153 return false;
154 }
155
156 if (!IsContain(mul_types_, primitive_c->name())) {
157 return false;
158 }
159
160 if (node->size() < 3) {
161 MS_LOG(INFO) << node->fullname_with_scope() << " input size less!";
162 return false;
163 }
164
165 auto inputNode1 = node->input(1);
166 auto inputNode2 = node->input(2);
167 if (inputNode1 == nullptr || inputNode2 == nullptr) {
168 MS_LOG(INFO) << node->fullname_with_scope() << " mul input is nullptr!";
169 return false;
170 }
171
172 ParameterPtr paramNode = nullptr;
173 if (inputNode1->isa<Parameter>()) {
174 paramNode = inputNode1->cast<ParameterPtr>();
175 } else if (inputNode2->isa<Parameter>()) {
176 paramNode = inputNode2->cast<ParameterPtr>();
177 }
178 if (paramNode == nullptr) {
179 MS_LOG(INFO) << node->fullname_with_scope() << " invalid paramNode!";
180 return false;
181 }
182
183 auto abstract_base = paramNode->abstract();
184 if (abstract_base == nullptr) {
185 MS_LOG(INFO) << "abstract is nullptr";
186 return false;
187 }
188
189 if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
190 MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
191 return false;
192 }
193 auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
194 size_t shapeSize = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
195 if (shapeSize < m_weight_size_) {
196 MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
197 return false;
198 }
199 return true;
200 }
201
CanTensorQuantized(const AnfNodePtr & inputNode) const202 bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
203 if (inputNode == nullptr) {
204 MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
205 return false;
206 }
207 ParameterPtr paramNode = nullptr;
208 if (inputNode->isa<Parameter>()) {
209 paramNode = inputNode->cast<ParameterPtr>();
210 }
211 if (paramNode == nullptr) {
212 MS_LOG(INFO) << "CanTensorQuantized invalid paramNode!";
213 return false;
214 }
215 auto abstract_base = paramNode->abstract();
216 if (abstract_base == nullptr) {
217 MS_LOG(INFO) << "abstract is nullptr";
218 return false;
219 }
220 if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
221 MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
222 return false;
223 }
224 auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
225 MS_ASSERT(weight_shape != nullptr);
226 if (weight_shape.size() < kDim2) { // do not quant single dim tensors
227 return false;
228 }
229 size_t shapeSize = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
230 if (shapeSize < m_weight_size_) {
231 MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
232 return false;
233 }
234 if (weight_shape.size() == kDim4) { // assume Convolution
235 if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
236 MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
237 return false;
238 }
239 }
240
241 return true;
242 }
243
GetCNodeQuantHolder(const PrimitivePtr & primitive)244 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
245 MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
246 QuantParamHolderPtr quant_params_holder = nullptr;
247 auto quant_params_valueptr = primitive->GetAttr("quant_params");
248 if (quant_params_valueptr == nullptr) {
249 quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
250 MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
251 primitive->AddAttr("quant_params", quant_params_holder);
252 } else {
253 quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>();
254 if (quant_params_holder == nullptr) {
255 quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
256 MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
257 primitive->AddAttr("quant_params", quant_params_holder);
258 }
259 }
260 return quant_params_holder;
261 }
262
TensorQuantParamsInited(const schema::TensorT & tensor)263 bool TensorQuantParamsInited(const schema::TensorT &tensor) {
264 if (tensor.quantParams.empty()) {
265 return false;
266 }
267
268 for (auto &quant_param : tensor.quantParams) {
269 if (!quant_param->inited) {
270 return false;
271 }
272 }
273 return true;
274 }
275
CalQuantizationParams(schema::QuantParamT * quantParam,double mMin,double mMax,bool narrowRange,int numBits)276 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int numBits) {
277 MS_ASSERT(quantParam != nullptr);
278 if (mMin > 0.0f) {
279 MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
280 mMin = 0.0f;
281 }
282 if (mMax < 0.0f) {
283 MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
284 mMax = 0.0f;
285 }
286 if (mMin > mMax) {
287 MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
288 return RET_PARAM_INVALID;
289 }
290 if (mMin == mMax) {
291 if (mMin != 0.0f) {
292 MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
293 return RET_ERROR;
294 }
295 quantParam->inited = true;
296 quantParam->min = mMin;
297 quantParam->max = mMax;
298 quantParam->scale = 0.0f;
299 quantParam->zeroPoint = 0;
300 quantParam->narrowRange = narrowRange;
301 quantParam->numBits = numBits;
302 return RET_OK;
303 }
304
305 const int8_t quantMax = (1 << (static_cast<unsigned int>(numBits - 1))) - 1;
306 const int8_t quantMin = -1 * (1 << (static_cast<unsigned int>(numBits - 1))) + (narrowRange ? 1 : 0);
307 auto quantMinFloat = static_cast<double>(quantMin);
308 auto quantMaxFloat = static_cast<double>(quantMax);
309 if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
310 MS_LOG(ERROR) << "divisor cannot be 0";
311 return RET_ERROR;
312 }
313 double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
314 if (fabs(scale) <= 0.0f) {
315 MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
316 return RET_ERROR;
317 }
318 const double zeroPointFromMin = quantMinFloat - mMin / scale;
319 const double zeroPointFromMax = quantMaxFloat - mMax / scale;
320 const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
321 const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
322 const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
323 int zeroPoint;
324 if (zpDouble < quantMinFloat) {
325 zeroPoint = quantMin;
326 } else if (zpDouble > quantMaxFloat) {
327 zeroPoint = quantMax;
328 } else {
329 zeroPoint = static_cast<int32_t>(std::round(zpDouble));
330 }
331 if (std::abs(mMin) == std::abs(mMax)) {
332 zeroPoint = 0;
333 }
334 // The zero point should always be in the range of quantized value,
335 // [qmin, qmax].
336 MS_ASSERT(zeroPoint >= quantMin);
337 MS_ASSERT(zeroPoint <= quantMax);
338 quantParam->inited = true;
339 quantParam->min = mMin;
340 quantParam->max = mMax;
341 quantParam->scale = scale;
342 quantParam->zeroPoint = zeroPoint;
343 quantParam->narrowRange = narrowRange;
344 quantParam->numBits = numBits;
345
346 return RET_OK;
347 }
348
SearchLowerBound(const std::vector<float> & data,const size_t & index,const float & max_tmp,float * min_tmp,size_t * min_idx)349 static bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp, float *min_tmp,
350 size_t *min_idx) {
351 MS_ASSERT(!data.empty());
352 size_t length = data.size();
353 if (max_tmp - data.at(index) < delta) {
354 return false;
355 }
356 if (fabs(max_tmp - *min_tmp) <= 0.0f || fabs(length - *min_idx) <= 0.0f) {
357 MS_LOG(INFO) << "divisor cannot be 0";
358 return false;
359 }
360 float range_ratio = (data.at(index) - *min_tmp) / (max_tmp - *min_tmp);
361 float index_ratio = static_cast<float>(index - *min_idx) / (length - *min_idx);
362 if (fabs(index_ratio) <= 0.0f) {
363 MS_LOG(INFO) << "divisor cannot be 0";
364 return false;
365 }
366 if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
367 *min_idx = index;
368 *min_tmp = data.at(index);
369 }
370 return true;
371 }
372
SearchUpperBound(const std::vector<float> & data,const size_t & index,float * max_tmp,const float & min_tmp,size_t * max_idx)373 static bool SearchUpperBound(const std::vector<float> &data, const size_t &index, float *max_tmp, const float &min_tmp,
374 size_t *max_idx) {
375 MS_ASSERT(!data.empty());
376 size_t length = data.size();
377 if (data.at(index) - min_tmp < delta) {
378 return false;
379 }
380 if (fabs(*max_tmp - min_tmp) <= 0.0f || fabs(length - *max_idx) <= 0.0f) {
381 MS_LOG(INFO) << "divisor cannot be 0";
382 return false;
383 }
384 float range_ratio = (*max_tmp - data.at(index)) / (*max_tmp - min_tmp);
385 float index_ratio = static_cast<float>(index - *max_idx) / (length - *max_idx);
386 if (fabs(index_ratio) <= 0.0f) {
387 MS_LOG(INFO) << "divisor cannot be 0";
388 return false;
389 }
390 if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
391 *max_idx = index;
392 *max_tmp = data.at(index);
393 }
394 return true;
395 }
396
CalPercentile(const std::vector<float> & data,const int & outlier_percent)397 static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
398 MS_ASSERT(!data.empty());
399 const int size = data.size();
400 float val = outlier_percent / kPercentBase * size;
401 int index = std::ceil(val);
402 float result;
403 if (index - val > 0) {
404 MS_ASSERT(index - 1 >= 0);
405 result = data.at(index - 1);
406 } else {
407 MS_ASSERT(index - 1 >= 0);
408 result = (data.at(index - 1) + data.at(index)) / 2;
409 }
410 return result;
411 }
412
OutlierMethod(std::vector<float> min_datas,std::vector<float> max_datas)413 std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas) {
414 MS_ASSERT(!min_datas.empty());
415 MS_ASSERT(!max_datas.empty());
416 std::sort(max_datas.begin(), max_datas.end());
417 std::sort(min_datas.begin(), min_datas.end());
418 float min_val = CalPercentile(min_datas, percent);
419 float max_val = CalPercentile(max_datas, kPercentBase - percent);
420 std::reverse(max_datas.begin(), max_datas.end());
421 MS_ASSERT(min_val < max_val);
422 MS_ASSERT(min_datas.size() == max_datas.size());
423 float min_tmp = min_val;
424 float max_tmp = max_val;
425 size_t min_idx = 0;
426 size_t max_idx = 0;
427 size_t length = min_datas.size();
428 for (size_t i = 0; i < length; i++) {
429 if (!SearchLowerBound(min_datas, i, max_tmp, &min_tmp, &min_idx)) {
430 break;
431 }
432 if (!SearchUpperBound(min_datas, i, &max_tmp, min_tmp, &max_idx)) {
433 break;
434 }
435 }
436 std::pair<float, float> result{min_tmp, max_tmp};
437 return result;
438 }
439
// Picks `k` initial cluster centres for KMeans from the distinct values of data[0, elem_count).
//
// The distinct values are sorted ascending and the centres are spread evenly over them: the first
// centre is the smallest value, the last centre is the largest value, and a centre that falls
// between two neighbouring values is their average. For k == 1 the smallest value is used.
// Returns an empty vector when `data` is null, `k` is 0, or there are fewer than `k` distinct values.
static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k) {
  std::vector<float> clusters;
  if (data == nullptr || elem_count == 0 || k == 0) {
    return clusters;
  }
  // std::set de-duplicates and keeps the values in ascending order, so no extra sort is needed.
  const std::set<float> unique_values(data, data + elem_count);
  if (unique_values.size() < k) {
    return clusters;
  }
  const std::vector<float> sorted_unique(unique_values.begin(), unique_values.end());
  const size_t last = sorted_unique.size() - 1;
  clusters.reserve(k);
  if (k == 1) {
    // A single centre has no spacing to compute (k - 1 would be a zero divisor).
    clusters.emplace_back(sorted_unique.front());
    return clusters;
  }
  for (size_t i = 0; i < k; ++i) {
    // Map i in [0, k-1] onto [0, last]; multiplying first keeps the end point exactly at `last`.
    const double position = static_cast<double>(i) * static_cast<double>(last) / static_cast<double>(k - 1);
    const size_t index = std::min(static_cast<size_t>(std::floor(position)), last);
    if (position > static_cast<double>(index) && index < last) {
      clusters.emplace_back((sorted_unique[index] + sorted_unique[index + 1]) / 2);
    } else {
      clusters.emplace_back(sorted_unique[index]);
    }
  }
  return clusters;
}
466
KMeans(float * data,size_t elem_count,size_t k,size_t epochs,schema::QuantParamT * quantParam)467 std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) {
468 MS_ASSERT(data != nullptr);
469 MS_CHECK_TRUE_MSG(elem_count != 0, std::vector<int8_t>{}, "elem_count is zero.");
470 std::vector<float> clusters = InitClusters(data, elem_count, k);
471 std::vector<int8_t> clusters_index{};
472 double error{0};
473 if (clusters.size() < k) {
474 MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed.";
475 return clusters_index;
476 }
477 for (size_t epoch = 0; epoch < epochs; epoch++) {
478 double error_cur{0};
479 clusters_index.clear();
480 std::vector<std::vector<float>> clusters_data(clusters.size());
481 for (size_t i = 0; i < elem_count; i++) {
482 size_t index = 0;
483 float min_distance = pow(data[i] - clusters[0], 2);
484 for (size_t j = 1; j < clusters.size(); j++) {
485 if (pow(data[i] - clusters[j], 2) < min_distance) {
486 min_distance = pow(data[i] - clusters[j], 2);
487 index = j;
488 }
489 }
490 clusters_index.emplace_back(index + INT8_MIN);
491 clusters_data[index].emplace_back(data[i]);
492 }
493 for (size_t j = 0; j < clusters.size(); j++) {
494 if (!clusters_data[j].empty()) {
495 clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
496 }
497 }
498 // compare error
499 for (size_t j = 0; j < elem_count; j++) {
500 error_cur += pow(data[j] - clusters[clusters_index[j]], 2);
501 }
502 error_cur = pow(error_cur / elem_count, 0.5);
503 if (std::abs((error_cur - error) / error_cur) <= 0.0f) {
504 break;
505 }
506 error = error_cur;
507 }
508 // update data
509 return clusters_index;
510 }
511
NodePrimitiveType(const CNodePtr & cnode)512 std::string NodePrimitiveType(const CNodePtr &cnode) {
513 if (cnode == nullptr) {
514 MS_LOG(ERROR) << "cnode is null";
515 return "";
516 }
517 auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
518 if (primitive_c == nullptr) {
519 MS_LOG(ERROR) << "primitive_c is null";
520 return "";
521 }
522 return primitive_c->name();
523 }
524
CreateSessionByFuncGraph(const FuncGraphPtr & func_graph,const converter::Flags & flags,int thread_num)525 SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num) {
526 SessionModel sm;
527 auto meta_graph = Export(func_graph, true, true);
528 if (meta_graph == nullptr) {
529 MS_LOG(ERROR) << "Export to meta_graph failed";
530 return sm;
531 }
532
533 // transform
534 GraphDefTransform fb_transform;
535 fb_transform.SetGraphDef(meta_graph);
536 auto status = fb_transform.Transform(flags);
537 if (status != RET_OK) {
538 MS_LOG(ERROR) << "FBTransform model failed";
539 return sm;
540 }
541 meta_graph->version = Version();
542
543 flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
544 auto offset = schema::MetaGraph::Pack(builder, meta_graph);
545 builder.Finish(offset);
546 schema::FinishMetaGraphBuffer(builder, offset);
547 auto size = builder.GetSize();
548 auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
549 if (content == nullptr) {
550 MS_LOG(ERROR) << "GetBufferPointer return null";
551 return sm;
552 }
553 auto model = lite::Model::Import(content, size);
554 if (model == nullptr) {
555 MS_LOG(ERROR) << "Import model failed";
556 return sm;
557 }
558 Context ctx;
559 ctx.thread_num_ = thread_num;
560 auto session = session::LiteSession::CreateSession(&ctx);
561 if (session == nullptr) {
562 MS_LOG(ERROR) << "create session failed.";
563 model->Free();
564 delete meta_graph;
565 return sm;
566 }
567
568 status = session->CompileGraph(model);
569 if (status != RET_OK) {
570 MS_LOG(ERROR) << "CompileGraph error";
571 model->Free();
572 delete meta_graph;
573 delete session;
574 return sm;
575 }
576 model->Free();
577 delete meta_graph;
578 sm.session = session;
579 sm.model = model;
580 return sm;
581 }
582
// Clones `func_graph` and gives every Parameter input of the cloned CNodes its own tensor data.
//
// For each Parameter a new tensor info is created from the bytes, shape and data type of its
// current default parameter and installed through lite::InitParameterFromTensorInfo.
// Returns nullptr when the input or the clone is null, when a cloned CNode has no same-named CNode
// in the source graph, when a Parameter has no default value, or when creating / installing the
// new tensor fails.
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  Cloner cloner({func_graph}, true, true, true, std::make_shared<TraceCopy>(), nullptr);
  auto new_func_graph = cloner[func_graph];
  MS_CHECK_TRUE_MSG(new_func_graph != nullptr, nullptr, "Clone func_graph failed.");

  // Only the names are needed to verify that every cloned node has a counterpart.
  std::set<std::string> old_cnode_names;
  for (const auto &cnode : func_graph->GetOrderedCnodes()) {
    old_cnode_names.insert(cnode->fullname_with_scope());
  }

  for (const auto &cnode : new_func_graph->GetOrderedCnodes()) {
    const auto cnode_name = cnode->fullname_with_scope();
    if (old_cnode_names.count(cnode_name) == 0) {
      MS_LOG(ERROR) << "can not find node: " << cnode_name;
      return nullptr;
    }
    for (const auto &input_node : cnode->inputs()) {
      if (input_node == nullptr || !input_node->isa<Parameter>()) {
        continue;
      }
      auto param_node = input_node->cast<ParameterPtr>();
      if (param_node == nullptr || !param_node->has_default()) {
        MS_LOG(ERROR) << "Param node has no default parameter: " << cnode_name;
        return nullptr;
      }
      auto old_tensor_info = std::static_pointer_cast<tensor::Tensor>(param_node->default_param());
      if (old_tensor_info == nullptr) {
        MS_LOG(ERROR) << "Default param of param node is not a tensor info:" << cnode_name;
        return nullptr;
      }
      auto new_tensor_info = lite::CreateTensorInfo(old_tensor_info->data().data(), old_tensor_info->data().nbytes(),
                                                    old_tensor_info->shape(), old_tensor_info->data_type());
      if (new_tensor_info == nullptr) {
        MS_LOG(ERROR) << "Create tensor info failed";
        return nullptr;
      }
      auto status = lite::InitParameterFromTensorInfo(param_node, new_tensor_info);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "init parameter from tensor info failed";
        return nullptr;
      }
    }  // end inputs loop
  }    // end cnodes loop
  return new_func_graph;
}
630
GetLiteParameter(const AnfNodePtr & node,ParameterPtr * param_node,tensor::TensorPtr * tensor_info)631 void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
632 if (node == nullptr) {
633 MS_LOG(ERROR) << "node is nullptr";
634 return;
635 }
636 auto op_name = node->fullname_with_scope();
637
638 *param_node = node->cast<ParameterPtr>();
639 if (*param_node == nullptr) {
640 MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
641 return;
642 }
643 if (!(*param_node)->has_default()) {
644 MS_LOG(INFO) << op_name << " not has_default";
645 return;
646 }
647
648 *tensor_info = std::static_pointer_cast<tensor::Tensor>((*param_node)->default_param());
649 if (*tensor_info == nullptr) {
650 MS_LOG(INFO) << "default_param can not cast to tensor::Tensor";
651 return;
652 }
653 }
654
UpdateTensorDataAndSize(const tensor::TensorPtr & weight,void * quant_datas,int new_size,TypeId new_data_type)655 STATUS UpdateTensorDataAndSize(const tensor::TensorPtr &weight, void *quant_datas, int new_size, TypeId new_data_type) {
656 MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
657 MS_CHECK_TRUE_RET(new_size > 0, RET_NULL_PTR);
658 weight->set_data_type(new_data_type);
659 if (new_size != weight->data().nbytes()) {
660 MS_LOG(ERROR) << "Data size of tensor info is error.";
661 return RET_ERROR;
662 }
663 if (memcpy_s(weight->data_c(), new_size, quant_datas, new_size) != EOK) {
664 MS_LOG(ERROR) << "memcpy data failed.";
665 return RET_ERROR;
666 }
667 return RET_OK;
668 }
669
CalChannels(const ShapeVector & dims,int channel_cnt,bool * channel_at_first)670 int CalChannels(const ShapeVector &dims, int channel_cnt, bool *channel_at_first) {
671 auto channels = dims[0];
672 if (!(*channel_at_first)) {
673 if (dims.size() != 2) {
674 MS_LOG(WARNING) << "unexpected dims size: " << dims.size();
675 *channel_at_first = true;
676 } else {
677 channels = dims[1];
678 }
679 } else {
680 channels = channel_cnt == -1 ? channels : channel_cnt;
681 }
682 return channels;
683 }
684
CalQuantAssitInfo(const PrimitivePtr & primitive,const ShapeVector & shapes,int index,bool * channel_at_first,int * channel_cnt)685 void CalQuantAssitInfo(const PrimitivePtr &primitive, const ShapeVector &shapes, int index, bool *channel_at_first,
686 int *channel_cnt) {
687 MS_ASSERT(primitive != nullptr);
688 if (shapes.empty()) {
689 MS_LOG(ERROR) << " shape vector is empty.";
690 return;
691 }
692 if (primitive->name() == ops::kNameMatMul && static_cast<int>(shapes.size()) == DIMENSION_2D) {
693 auto matmul_prim = primitive->cast<std::shared_ptr<ops::MatMul>>();
694 MS_ASSERT(matmul_prim != nullptr);
695 *channel_at_first =
696 index != 1 || (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b());
697 } else if (primitive->name() == ops::kNameLSTM) {
698 if (index == kLstmInputWeightIndex || index == kLstmStateWeightIndex) {
699 if (shapes.size() != kLstmWeightShapeSize) {
700 MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
701 } else {
702 *channel_cnt = shapes[0] * shapes[1];
703 }
704 } else if (index == kLstmBiasIndex) {
705 if (shapes.size() != kLstmBiasShapeSize) {
706 MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
707 } else {
708 auto tensor_elem_cnt = shapes[0] * shapes[1];
709 if (tensor_elem_cnt % kSingleDirBiasTensorSize == 0) {
710 *channel_cnt = kSingleDirBiasTensorSize;
711 }
712 }
713 } else {
714 MS_LOG(WARNING) << "unexpected index of lstm: " << index;
715 }
716 }
717 }
718
CalQuantAssitInfo(const schema::PrimitiveT & primitive,const std::vector<int> & shapes,int index,bool * channel_at_first,int * channel_cnt)719 void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<int> &shapes, int index,
720 bool *channel_at_first, int *channel_cnt) {
721 MS_ASSERT(primitive != nullptr);
722 if (shapes.empty()) {
723 MS_LOG(ERROR) << " shape vector is empty.";
724 return;
725 }
726 if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == kDim2) {
727 auto matmul_prim = primitive.value.AsMatMul();
728 MS_ASSERT(matmul_prim != nullptr);
729 *channel_at_first = index != 1 || matmul_prim->transpose_b;
730 } else if (primitive.value.type == schema::PrimitiveType_LSTM) {
731 if (index == kLstmInputWeightIndex || index == kLstmStateWeightIndex) {
732 if (shapes.size() != kLstmWeightShapeSize) {
733 MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
734 } else {
735 *channel_cnt = shapes[0] * shapes[1];
736 }
737 } else if (index == kLstmBiasIndex) {
738 if (shapes.size() != kLstmBiasShapeSize) {
739 MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
740 } else {
741 auto tensor_elem_cnt = shapes[0] * shapes[1];
742 if (tensor_elem_cnt % kSingleDirBiasTensorSize == 0) {
743 *channel_cnt = kSingleDirBiasTensorSize;
744 }
745 }
746 } else {
747 MS_LOG(WARNING) << "unexpected index of lstm: " << index;
748 }
749 }
750 }
751
MixedBitQuantFilter(const tensor::TensorPtr & weight,const PrimitivePtr & primitive,QuantType quant_type,WeightQuantType weight_quant_type,TypeId quant_data_type,double init_scale,int index)752 STATUS MixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
753 WeightQuantType weight_quant_type, TypeId quant_data_type, double init_scale, int index) {
754 MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
755 MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
756 auto dims = weight->shape();
757 if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
758 if (dims.size() <= 1) {
759 MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel";
760 weight_quant_type = FIXED_BIT_PER_LAYER;
761 }
762 }
763 std::vector<schema::QuantParamT> quant_params;
764 size_t elem_count = weight->DataSize();
765 auto *raw_data = static_cast<float *>(weight->data_c());
766 if (raw_data == nullptr) {
767 MS_LOG(ERROR) << "rawDatas is nullptr";
768 return RET_ERROR;
769 }
770
771 std::vector<int16_t> quant_data(elem_count);
772 int ret = RET_OK;
773 if (weight_quant_type == MIXED_BIT_PER_LAYER) {
774 MixedBitWeightQuantizer quantizer(init_scale);
775 quantizer.DoQuantization(static_cast<float *>(weight->data_c()), weight->shape_c(), 0, &quant_params, &quant_data);
776 } else {
777 MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
778 }
779 auto status =
780 UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), quant_data_type);
781 if (status != RET_OK) {
782 MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
783 return RET_ERROR;
784 }
785
786 if (quant_params.empty()) {
787 MS_LOG(ERROR) << "quant_params empty";
788 return RET_ERROR;
789 }
790 auto quant_param_holder = GetCNodeQuantHolder(primitive);
791 quant_param_holder->set_input_quant_param(index, quant_params);
792 return ret;
793 }
CheckNodeInSet(const CNodePtr & cnode,const std::set<PrimitivePtr> & support_primitive_types)794 bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types) {
795 for (const auto &type : support_primitive_types) {
796 if (opt::CheckPrimitiveType(cnode, type)) {
797 return true;
798 }
799 }
800 return false;
801 }
802 } // namespace mindspore::lite::quant
803