• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <algorithm>
#include <cmath>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
27 #include "include/version.h"
28 #include "ops/fusion/conv2d_fusion.h"
29 #include "ops/fusion/conv2d_transpose_fusion.h"
30 #include "ops/fusion/full_connection.h"
31 #include "ops/mat_mul.h"
32 #include "tools/converter/ops/ops_def.h"
33 #include "tools/anf_exporter/anf_exporter.h"
34 #include "tools/converter/quantizer/bitpacking.h"
35 #include "src/common/utils.h"
36 #include "tools/common/tensor_util.h"
37 #include "abstract/abstract_value.h"
38 #include "securec/include/securec.h"
39 #include "tools/optimizer/common/gllo_utils.h"
40 #include "tools/optimizer/common/format_utils.h"
41 
42 using std::string;
43 using std::vector;
44 
45 namespace mindspore::lite::quant {
// Operator name tables used to decide which node kinds are weight-quantizable.
const std::vector<std::string> QuantStrategy::conv_types_ = {ops::kNameConv2DFusion, ops::kNameConv2dTransposeFusion};
const std::vector<std::string> QuantStrategy::mul_types_ = {ops::kNameMatMul, ops::kNameFullConnection};
constexpr int kDim2 = 2;  // minimum tensor rank considered for weight quantization
constexpr int kDim4 = 4;  // rank of a convolution weight tensor

// LSTM weight/bias layout constants: input indices on the CNode and the
// expected shape ranks used when deriving per-channel quantization info.
const int kLstmInputWeightIndex = 1;
const int kLstmStateWeightIndex = 2;
const int kLstmWeightShapeSize = 3;
const int kSingleDirBiasTensorSize = 4;
const int kLstmBiasShapeSize = 2;
const int kLstmBiasIndex = 3;
57 
// Thresholds that gate weight quantization: minimum weight element count and
// minimum output-channel count for convolution weights.
QuantStrategy::QuantStrategy(size_t weight_size, size_t conv_weight_quant_channel_threshold)
    : m_weight_size_(weight_size), m_conv_weight_quant_channel_threshold_(conv_weight_quant_channel_threshold) {}
60 
// Returns true when a convolution node's weight qualifies for weight
// quantization: the op type must be in conv_types_, the weight input must be
// a Parameter with a known shape, and the weight must meet the configured
// element-count and output-channel thresholds.
bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
  MS_CHECK_TRUE_RET(node != nullptr, false);
  auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(node->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return false;
  }
  if (!IsContain(conv_types_, primitive_c->name())) {
    return false;
  }
  // inputs are [primitive, activation, weight, ...]; a weight input must exist.
  if (node->size() < 3) {
    return false;
  }
  auto inputNode = node->input(2);
  if (!inputNode->isa<Parameter>()) {
    return false;
  }
  auto paramNode = inputNode->cast<ParameterPtr>();
  MS_ASSERT(paramNode != nullptr);
  auto abstract_base = paramNode->abstract();
  if (abstract_base == nullptr) {
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  // total element count of the weight tensor
  size_t shapeSize = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
  if (shapeSize < m_weight_size_) {
    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
    return false;
  }
  // weight_shape[0] is the output-channel dimension of a conv weight.
  if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
    MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
    return false;
  }
  return true;
}
100 
CanOpFullQuantized(const AnfNodePtr & node)101 bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
102   MS_CHECK_TRUE_RET(node != nullptr, false);
103   if (!node->isa<mindspore::CNode>()) {
104     return false;
105   }
106   const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
107   MS_ASSERT(cnode != nullptr);
108   auto type = NodePrimitiveType(cnode);
109   static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion,      prim::kPrimActivation,
110                                                           prim::kPrimAvgPoolFusion,  prim::kPrimConcat,
111                                                           prim::kPrimConv2DFusion,   prim::kPrimConv2dTransposeFusion,
112                                                           prim::kPrimCrop,           prim::kPrimFullConnection,
113                                                           prim::kPrimGather,         prim::kPrimLayerNormFusion,
114                                                           prim::kPrimMatMul,         prim::kPrimMaxPoolFusion,
115                                                           prim::kPrimMulFusion,      prim::kPrimReshape,
116                                                           prim::kPrimSplit,          prim::kPrimTranspose,
117                                                           prim::kPrimReduceFusion,   prim::kPrimDivFusion,
118                                                           prim::kPrimSqrt,           prim::kPrimPowFusion,
119                                                           prim::kPrimSubFusion,      prim::kPrimUnsqueeze,
120                                                           prim::kPrimLayerNormFusion};
121   // The return node does not need to be quantified.
122   if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn) || opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
123     return false;
124   }
125   // These operators do not need to check the data type.
126   if (opt::CheckPrimitiveType(cnode, prim::kPrimShape) || opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
127     return true;
128   }
129   auto is_support_node = CheckNodeInSet(cnode, support_int8_ops);
130   if (!is_support_node && type != "Eltwise") {
131     MS_LOG(WARNING) << "node:" << cnode->fullname_with_scope() << " type:" << type << " is not support quantization.";
132     return false;
133   }
134   TypeId type_id;
135   auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
136   if (ret != RET_OK) {
137     MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
138     return ret;
139   }
140 
141   bool is_data_type_fp32 = type_id == kNumberTypeFloat32;
142   if (!is_data_type_fp32) {
143     MS_LOG(INFO) << cnode->fullname_with_scope() << "  type_id is " << type_id << " , and is not float32.";
144   }
145   return is_data_type_fp32;
146 }
147 
// Returns true when a MatMul/FullConnection node's constant operand qualifies
// for weight quantization: one of the two inputs must be a Parameter with a
// known shape whose element count meets the configured threshold.
bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
  MS_CHECK_TRUE_RET(node != nullptr, false);
  auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(node->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return false;
  }

  if (!IsContain(mul_types_, primitive_c->name())) {
    return false;
  }

  // inputs are [primitive, lhs, rhs]; both operands must be present.
  if (node->size() < 3) {
    MS_LOG(INFO) << node->fullname_with_scope() << " input size less!";
    return false;
  }

  auto inputNode1 = node->input(1);
  auto inputNode2 = node->input(2);
  if (inputNode1 == nullptr || inputNode2 == nullptr) {
    MS_LOG(INFO) << node->fullname_with_scope() << " mul input is nullptr!";
    return false;
  }

  // The weight may be on either side; take the first Parameter operand.
  ParameterPtr paramNode = nullptr;
  if (inputNode1->isa<Parameter>()) {
    paramNode = inputNode1->cast<ParameterPtr>();
  } else if (inputNode2->isa<Parameter>()) {
    paramNode = inputNode2->cast<ParameterPtr>();
  }
  if (paramNode == nullptr) {
    MS_LOG(INFO) << node->fullname_with_scope() << " invalid paramNode!";
    return false;
  }

  auto abstract_base = paramNode->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(INFO) << "abstract is nullptr";
    return false;
  }

  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  // total element count of the weight tensor
  size_t shapeSize = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
  if (shapeSize < m_weight_size_) {
    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
    return false;
  }
  return true;
}
201 
CanTensorQuantized(const AnfNodePtr & inputNode) const202 bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
203   if (inputNode == nullptr) {
204     MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
205     return false;
206   }
207   ParameterPtr paramNode = nullptr;
208   if (inputNode->isa<Parameter>()) {
209     paramNode = inputNode->cast<ParameterPtr>();
210   }
211   if (paramNode == nullptr) {
212     MS_LOG(INFO) << "CanTensorQuantized invalid paramNode!";
213     return false;
214   }
215   auto abstract_base = paramNode->abstract();
216   if (abstract_base == nullptr) {
217     MS_LOG(INFO) << "abstract is nullptr";
218     return false;
219   }
220   if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
221     MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
222     return false;
223   }
224   auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
225   MS_ASSERT(weight_shape != nullptr);
226   if (weight_shape.size() < kDim2) {  // do not quant single dim tensors
227     return false;
228   }
229   size_t shapeSize = std::accumulate(weight_shape.begin(), weight_shape.end(), 1, std::multiplies<int>());
230   if (shapeSize < m_weight_size_) {
231     MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
232     return false;
233   }
234   if (weight_shape.size() == kDim4) {  // assume Convolution
235     if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
236       MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
237       return false;
238     }
239   }
240 
241   return true;
242 }
243 
GetCNodeQuantHolder(const PrimitivePtr & primitive)244 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
245   MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
246   QuantParamHolderPtr quant_params_holder = nullptr;
247   auto quant_params_valueptr = primitive->GetAttr("quant_params");
248   if (quant_params_valueptr == nullptr) {
249     quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
250     MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
251     primitive->AddAttr("quant_params", quant_params_holder);
252   } else {
253     quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>();
254     if (quant_params_holder == nullptr) {
255       quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
256       MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
257       primitive->AddAttr("quant_params", quant_params_holder);
258     }
259   }
260   return quant_params_holder;
261 }
262 
TensorQuantParamsInited(const schema::TensorT & tensor)263 bool TensorQuantParamsInited(const schema::TensorT &tensor) {
264   if (tensor.quantParams.empty()) {
265     return false;
266   }
267 
268   for (auto &quant_param : tensor.quantParams) {
269     if (!quant_param->inited) {
270       return false;
271     }
272   }
273   return true;
274 }
275 
CalQuantizationParams(schema::QuantParamT * quantParam,double mMin,double mMax,bool narrowRange,int numBits)276 STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int numBits) {
277   MS_ASSERT(quantParam != nullptr);
278   if (mMin > 0.0f) {
279     MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
280     mMin = 0.0f;
281   }
282   if (mMax < 0.0f) {
283     MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
284     mMax = 0.0f;
285   }
286   if (mMin > mMax) {
287     MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
288     return RET_PARAM_INVALID;
289   }
290   if (mMin == mMax) {
291     if (mMin != 0.0f) {
292       MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
293       return RET_ERROR;
294     }
295     quantParam->inited = true;
296     quantParam->min = mMin;
297     quantParam->max = mMax;
298     quantParam->scale = 0.0f;
299     quantParam->zeroPoint = 0;
300     quantParam->narrowRange = narrowRange;
301     quantParam->numBits = numBits;
302     return RET_OK;
303   }
304 
305   const int8_t quantMax = (1 << (static_cast<unsigned int>(numBits - 1))) - 1;
306   const int8_t quantMin = -1 * (1 << (static_cast<unsigned int>(numBits - 1))) + (narrowRange ? 1 : 0);
307   auto quantMinFloat = static_cast<double>(quantMin);
308   auto quantMaxFloat = static_cast<double>(quantMax);
309   if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
310     MS_LOG(ERROR) << "divisor cannot be 0";
311     return RET_ERROR;
312   }
313   double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
314   if (fabs(scale) <= 0.0f) {
315     MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
316     return RET_ERROR;
317   }
318   const double zeroPointFromMin = quantMinFloat - mMin / scale;
319   const double zeroPointFromMax = quantMaxFloat - mMax / scale;
320   const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
321   const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
322   const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
323   int zeroPoint;
324   if (zpDouble < quantMinFloat) {
325     zeroPoint = quantMin;
326   } else if (zpDouble > quantMaxFloat) {
327     zeroPoint = quantMax;
328   } else {
329     zeroPoint = static_cast<int32_t>(std::round(zpDouble));
330   }
331   if (std::abs(mMin) == std::abs(mMax)) {
332     zeroPoint = 0;
333   }
334   // The zero point should always be in the range of quantized value,
335   // [qmin, qmax].
336   MS_ASSERT(zeroPoint >= quantMin);
337   MS_ASSERT(zeroPoint <= quantMax);
338   quantParam->inited = true;
339   quantParam->min = mMin;
340   quantParam->max = mMax;
341   quantParam->scale = scale;
342   quantParam->zeroPoint = zeroPoint;
343   quantParam->narrowRange = narrowRange;
344   quantParam->numBits = numBits;
345 
346   return RET_OK;
347 }
348 
// One step of the lower-bound search used by OutlierMethod: walking sorted
// min values from the smallest, advance the candidate lower bound to
// data[index] when the value-range shrink justifies the index advance
// (range_ratio / index_ratio > ratio). Returns false to stop the scan.
// NOTE(review): `delta` and `ratio` come from an outside header — presumably
// tuning constants for the outlier cut-off; confirm their values there.
static bool SearchLowerBound(const std::vector<float> &data, const size_t &index, const float &max_tmp, float *min_tmp,
                             size_t *min_idx) {
  MS_ASSERT(!data.empty());
  size_t length = data.size();
  // Remaining range too small to be worth tightening further.
  if (max_tmp - data.at(index) < delta) {
    return false;
  }
  if (fabs(max_tmp - *min_tmp) <= 0.0f || fabs(length - *min_idx) <= 0.0f) {
    MS_LOG(INFO) << "divisor cannot be 0";
    return false;
  }
  // Fraction of the value range covered vs. fraction of the indices consumed.
  float range_ratio = (data.at(index) - *min_tmp) / (max_tmp - *min_tmp);
  float index_ratio = static_cast<float>(index - *min_idx) / (length - *min_idx);
  if (fabs(index_ratio) <= 0.0f) {
    MS_LOG(INFO) << "divisor cannot be 0";
    return false;
  }
  if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
    *min_idx = index;
    *min_tmp = data.at(index);
  }
  return true;
}
372 
// Mirror of SearchLowerBound for the upper bound: lowers the candidate upper
// bound to data[index] when the range shrink justifies the index advance.
// Returns false to stop the scan. Uses the same external `delta`/`ratio`
// tuning constants as SearchLowerBound.
static bool SearchUpperBound(const std::vector<float> &data, const size_t &index, float *max_tmp, const float &min_tmp,
                             size_t *max_idx) {
  MS_ASSERT(!data.empty());
  size_t length = data.size();
  // Remaining range too small to be worth tightening further.
  if (data.at(index) - min_tmp < delta) {
    return false;
  }
  if (fabs(*max_tmp - min_tmp) <= 0.0f || fabs(length - *max_idx) <= 0.0f) {
    MS_LOG(INFO) << "divisor cannot be 0";
    return false;
  }
  // Fraction of the value range covered vs. fraction of the indices consumed.
  float range_ratio = (*max_tmp - data.at(index)) / (*max_tmp - min_tmp);
  float index_ratio = static_cast<float>(index - *max_idx) / (length - *max_idx);
  if (fabs(index_ratio) <= 0.0f) {
    MS_LOG(INFO) << "divisor cannot be 0";
    return false;
  }
  if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
    *max_idx = index;
    *max_tmp = data.at(index);
  }
  return true;
}
396 
CalPercentile(const std::vector<float> & data,const int & outlier_percent)397 static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
398   MS_ASSERT(!data.empty());
399   const int size = data.size();
400   float val = outlier_percent / kPercentBase * size;
401   int index = std::ceil(val);
402   float result;
403   if (index - val > 0) {
404     MS_ASSERT(index - 1 >= 0);
405     result = data.at(index - 1);
406   } else {
407     MS_ASSERT(index - 1 >= 0);
408     result = (data.at(index - 1) + data.at(index)) / 2;
409   }
410   return result;
411 }
412 
OutlierMethod(std::vector<float> min_datas,std::vector<float> max_datas)413 std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas) {
414   MS_ASSERT(!min_datas.empty());
415   MS_ASSERT(!max_datas.empty());
416   std::sort(max_datas.begin(), max_datas.end());
417   std::sort(min_datas.begin(), min_datas.end());
418   float min_val = CalPercentile(min_datas, percent);
419   float max_val = CalPercentile(max_datas, kPercentBase - percent);
420   std::reverse(max_datas.begin(), max_datas.end());
421   MS_ASSERT(min_val < max_val);
422   MS_ASSERT(min_datas.size() == max_datas.size());
423   float min_tmp = min_val;
424   float max_tmp = max_val;
425   size_t min_idx = 0;
426   size_t max_idx = 0;
427   size_t length = min_datas.size();
428   for (size_t i = 0; i < length; i++) {
429     if (!SearchLowerBound(min_datas, i, max_tmp, &min_tmp, &min_idx)) {
430       break;
431     }
432     if (!SearchUpperBound(min_datas, i, &max_tmp, min_tmp, &max_idx)) {
433       break;
434     }
435   }
436   std::pair<float, float> result{min_tmp, max_tmp};
437   return result;
438 }
439 
InitClusters(float * data,size_t elem_count,size_t k)440 static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k) {
441   MS_ASSERT(data != nullptr);
442   std::set<float> set_unique{};
443   for (size_t i = 0; i < elem_count; i++) {
444     set_unique.emplace(data[i]);
445   }
446   std::vector<float> data_unique;
447   data_unique.assign(set_unique.begin(), set_unique.end());
448   std::vector<float> clusters{};
449   if (set_unique.size() < k) {
450     return clusters;
451   }
452   // init cluster
453   MS_ASSERT(k != 1);
454   float cluster_ratio = static_cast<float>(data_unique.size()) / (k - 1);
455   std::sort(data_unique.begin(), data_unique.end());
456   for (size_t i = 0; i < k; i++) {
457     size_t index = std::floor(i * cluster_ratio);
458     if (i * cluster_ratio - index > 0) {
459       clusters.emplace_back((data_unique[index] + data_unique[index + 1]) / 2);
460     } else {
461       clusters.emplace_back(data_unique[index]);
462     }
463   }
464   return clusters;
465 }
466 
KMeans(float * data,size_t elem_count,size_t k,size_t epochs,schema::QuantParamT * quantParam)467 std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) {
468   MS_ASSERT(data != nullptr);
469   MS_CHECK_TRUE_MSG(elem_count != 0, std::vector<int8_t>{}, "elem_count is zero.");
470   std::vector<float> clusters = InitClusters(data, elem_count, k);
471   std::vector<int8_t> clusters_index{};
472   double error{0};
473   if (clusters.size() < k) {
474     MS_LOG(WARNING) << "K is less than the size of data so KMeans function is not executed.";
475     return clusters_index;
476   }
477   for (size_t epoch = 0; epoch < epochs; epoch++) {
478     double error_cur{0};
479     clusters_index.clear();
480     std::vector<std::vector<float>> clusters_data(clusters.size());
481     for (size_t i = 0; i < elem_count; i++) {
482       size_t index = 0;
483       float min_distance = pow(data[i] - clusters[0], 2);
484       for (size_t j = 1; j < clusters.size(); j++) {
485         if (pow(data[i] - clusters[j], 2) < min_distance) {
486           min_distance = pow(data[i] - clusters[j], 2);
487           index = j;
488         }
489       }
490       clusters_index.emplace_back(index + INT8_MIN);
491       clusters_data[index].emplace_back(data[i]);
492     }
493     for (size_t j = 0; j < clusters.size(); j++) {
494       if (!clusters_data[j].empty()) {
495         clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
496       }
497     }
498     // compare error
499     for (size_t j = 0; j < elem_count; j++) {
500       error_cur += pow(data[j] - clusters[clusters_index[j]], 2);
501     }
502     error_cur = pow(error_cur / elem_count, 0.5);
503     if (std::abs((error_cur - error) / error_cur) <= 0.0f) {
504       break;
505     }
506     error = error_cur;
507   }
508   // update data
509   return clusters_index;
510 }
511 
NodePrimitiveType(const CNodePtr & cnode)512 std::string NodePrimitiveType(const CNodePtr &cnode) {
513   if (cnode == nullptr) {
514     MS_LOG(ERROR) << "cnode is null";
515     return "";
516   }
517   auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
518   if (primitive_c == nullptr) {
519     MS_LOG(ERROR) << "primitive_c is null";
520     return "";
521   }
522   return primitive_c->name();
523 }
524 
CreateSessionByFuncGraph(const FuncGraphPtr & func_graph,const converter::Flags & flags,int thread_num)525 SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num) {
526   SessionModel sm;
527   auto meta_graph = Export(func_graph, true, true);
528   if (meta_graph == nullptr) {
529     MS_LOG(ERROR) << "Export to meta_graph failed";
530     return sm;
531   }
532 
533   // transform
534   GraphDefTransform fb_transform;
535   fb_transform.SetGraphDef(meta_graph);
536   auto status = fb_transform.Transform(flags);
537   if (status != RET_OK) {
538     MS_LOG(ERROR) << "FBTransform model failed";
539     return sm;
540   }
541   meta_graph->version = Version();
542 
543   flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
544   auto offset = schema::MetaGraph::Pack(builder, meta_graph);
545   builder.Finish(offset);
546   schema::FinishMetaGraphBuffer(builder, offset);
547   auto size = builder.GetSize();
548   auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
549   if (content == nullptr) {
550     MS_LOG(ERROR) << "GetBufferPointer return null";
551     return sm;
552   }
553   auto model = lite::Model::Import(content, size);
554   if (model == nullptr) {
555     MS_LOG(ERROR) << "Import model failed";
556     return sm;
557   }
558   Context ctx;
559   ctx.thread_num_ = thread_num;
560   auto session = session::LiteSession::CreateSession(&ctx);
561   if (session == nullptr) {
562     MS_LOG(ERROR) << "create session failed.";
563     model->Free();
564     delete meta_graph;
565     return sm;
566   }
567 
568   status = session->CompileGraph(model);
569   if (status != RET_OK) {
570     MS_LOG(ERROR) << "CompileGraph error";
571     model->Free();
572     delete meta_graph;
573     delete session;
574     return sm;
575   }
576   model->Free();
577   delete meta_graph;
578   sm.session = session;
579   sm.model = model;
580   return sm;
581 }
582 
// Deep-copies a FuncGraph via Cloner, then re-creates every Parameter's
// default tensor so the copy owns its own weight buffers instead of sharing
// them with the original. Returns nullptr on any mismatch or failure.
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  Cloner cloner({func_graph}, true, true, true, std::make_shared<TraceCopy>(), nullptr);
  auto new_func_graph = cloner[func_graph];

  // Map original cnodes by scope name so cloned nodes can be matched back.
  std::map<std::string, CNodePtr> old_cnode_map;
  for (const auto &cnode : func_graph->GetOrderedCnodes()) {
    old_cnode_map[cnode->fullname_with_scope()] = cnode;
  }

  for (auto &cnode : new_func_graph->GetOrderedCnodes()) {
    auto cnode_name = cnode->fullname_with_scope();
    auto old_cnode_iter = old_cnode_map.find(cnode_name);
    if (old_cnode_iter == old_cnode_map.end()) {
      MS_LOG(ERROR) << "can not find node: " << cnode_name;
      return nullptr;
    }
    auto old_cnode = old_cnode_iter->second;
    auto inputs = cnode->inputs();
    for (const auto &input_node : inputs) {
      if (input_node->isa<Parameter>()) {
        auto param_node = input_node->cast<ParameterPtr>();
        if (!param_node->has_default()) {
          MS_LOG(ERROR) << "Param node has no default parameter: " << cnode_name;
          return nullptr;
        }
        auto old_tensor_info = std::static_pointer_cast<tensor::Tensor>(param_node->default_param());
        if (old_tensor_info == nullptr) {
          MS_LOG(ERROR) << "Default param of param node is not a tensor info:" << cnode_name;
          return nullptr;
        }
        // Allocate a fresh tensor with copied bytes — breaks buffer sharing
        // with the original graph.
        auto new_tensor_info = lite::CreateTensorInfo(old_tensor_info->data().data(), old_tensor_info->data().nbytes(),
                                                      old_tensor_info->shape(), old_tensor_info->data_type());
        if (new_tensor_info == nullptr) {
          MS_LOG(ERROR) << "Create tensor info failed";
          return nullptr;
        }
        auto status = lite::InitParameterFromTensorInfo(param_node, new_tensor_info);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "init parameter from tensor info failed";
          return nullptr;
        }
      }
    }  // end inputs loop
  }    // end cnodes loop
  return new_func_graph;
}
630 
// Best-effort extraction of a Parameter node and its default tensor from an
// AnfNode. On any failure the function simply returns, leaving *param_node
// and/or *tensor_info unset — callers must check the outputs for null.
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is nullptr";
    return;
  }
  auto op_name = node->fullname_with_scope();

  *param_node = node->cast<ParameterPtr>();
  if (*param_node == nullptr) {
    MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
    return;
  }
  if (!(*param_node)->has_default()) {
    MS_LOG(INFO) << op_name << " not has_default";
    return;
  }

  *tensor_info = std::static_pointer_cast<tensor::Tensor>((*param_node)->default_param());
  if (*tensor_info == nullptr) {
    MS_LOG(INFO) << "default_param can not cast to tensor::Tensor";
    return;
  }
}
654 
// Replaces a weight tensor's contents with quantized data of a new type.
// NOTE(review): set_data_type is called before the size check — presumably it
// re-shapes the tensor's backing buffer for the new element type, so nbytes
// below reflects the NEW type; confirm against tensor::Tensor before
// reordering these statements.
STATUS UpdateTensorDataAndSize(const tensor::TensorPtr &weight, void *quant_datas, int new_size, TypeId new_data_type) {
  MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
  MS_CHECK_TRUE_RET(new_size > 0, RET_NULL_PTR);
  weight->set_data_type(new_data_type);
  if (new_size != weight->data().nbytes()) {
    MS_LOG(ERROR) << "Data size of tensor info is error.";
    return RET_ERROR;
  }
  if (memcpy_s(weight->data_c(), new_size, quant_datas, new_size) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
669 
CalChannels(const ShapeVector & dims,int channel_cnt,bool * channel_at_first)670 int CalChannels(const ShapeVector &dims, int channel_cnt, bool *channel_at_first) {
671   auto channels = dims[0];
672   if (!(*channel_at_first)) {
673     if (dims.size() != 2) {
674       MS_LOG(WARNING) << "unexpected dims size: " << dims.size();
675       *channel_at_first = true;
676     } else {
677       channels = dims[1];
678     }
679   } else {
680     channels = channel_cnt == -1 ? channels : channel_cnt;
681   }
682   return channels;
683 }
684 
// Determines per-channel quantization layout hints for a weight tensor:
// whether the channel dimension comes first (*channel_at_first) and, for
// LSTM weights/biases, an explicit channel count (*channel_cnt).
void CalQuantAssitInfo(const PrimitivePtr &primitive, const ShapeVector &shapes, int index, bool *channel_at_first,
                       int *channel_cnt) {
  MS_ASSERT(primitive != nullptr);
  if (shapes.empty()) {
    MS_LOG(ERROR) << " shape vector is empty.";
    return;
  }
  if (primitive->name() == ops::kNameMatMul && static_cast<int>(shapes.size()) == DIMENSION_2D) {
    // MatMul weight (input index 1) is channel-last unless transpose_b is set.
    auto matmul_prim = primitive->cast<std::shared_ptr<ops::MatMul>>();
    MS_ASSERT(matmul_prim != nullptr);
    *channel_at_first =
      index != 1 || (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b());
  } else if (primitive->name() == ops::kNameLSTM) {
    if (index == kLstmInputWeightIndex || index == kLstmStateWeightIndex) {
      if (shapes.size() != kLstmWeightShapeSize) {
        MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
      } else {
        // Channels span the first two dims of the 3-D LSTM weight.
        *channel_cnt = shapes[0] * shapes[1];
      }
    } else if (index == kLstmBiasIndex) {
      if (shapes.size() != kLstmBiasShapeSize) {
        MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
      } else {
        // Bias is split into per-direction chunks of kSingleDirBiasTensorSize.
        auto tensor_elem_cnt = shapes[0] * shapes[1];
        if (tensor_elem_cnt % kSingleDirBiasTensorSize == 0) {
          *channel_cnt = kSingleDirBiasTensorSize;
        }
      }
    } else {
      MS_LOG(WARNING) << "unexpected index of lstm: " << index;
    }
  }
}
718 
CalQuantAssitInfo(const schema::PrimitiveT & primitive,const std::vector<int> & shapes,int index,bool * channel_at_first,int * channel_cnt)719 void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<int> &shapes, int index,
720                        bool *channel_at_first, int *channel_cnt) {
721   MS_ASSERT(primitive != nullptr);
722   if (shapes.empty()) {
723     MS_LOG(ERROR) << " shape vector is empty.";
724     return;
725   }
726   if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == kDim2) {
727     auto matmul_prim = primitive.value.AsMatMul();
728     MS_ASSERT(matmul_prim != nullptr);
729     *channel_at_first = index != 1 || matmul_prim->transpose_b;
730   } else if (primitive.value.type == schema::PrimitiveType_LSTM) {
731     if (index == kLstmInputWeightIndex || index == kLstmStateWeightIndex) {
732       if (shapes.size() != kLstmWeightShapeSize) {
733         MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
734       } else {
735         *channel_cnt = shapes[0] * shapes[1];
736       }
737     } else if (index == kLstmBiasIndex) {
738       if (shapes.size() != kLstmBiasShapeSize) {
739         MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
740       } else {
741         auto tensor_elem_cnt = shapes[0] * shapes[1];
742         if (tensor_elem_cnt % kSingleDirBiasTensorSize == 0) {
743           *channel_cnt = kSingleDirBiasTensorSize;
744         }
745       }
746     } else {
747       MS_LOG(WARNING) << "unexpected index of lstm: " << index;
748     }
749   }
750 }
751 
MixedBitQuantFilter(const tensor::TensorPtr & weight,const PrimitivePtr & primitive,QuantType quant_type,WeightQuantType weight_quant_type,TypeId quant_data_type,double init_scale,int index)752 STATUS MixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
753                            WeightQuantType weight_quant_type, TypeId quant_data_type, double init_scale, int index) {
754   MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
755   MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
756   auto dims = weight->shape();
757   if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
758     if (dims.size() <= 1) {
759       MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel";
760       weight_quant_type = FIXED_BIT_PER_LAYER;
761     }
762   }
763   std::vector<schema::QuantParamT> quant_params;
764   size_t elem_count = weight->DataSize();
765   auto *raw_data = static_cast<float *>(weight->data_c());
766   if (raw_data == nullptr) {
767     MS_LOG(ERROR) << "rawDatas is nullptr";
768     return RET_ERROR;
769   }
770 
771   std::vector<int16_t> quant_data(elem_count);
772   int ret = RET_OK;
773   if (weight_quant_type == MIXED_BIT_PER_LAYER) {
774     MixedBitWeightQuantizer quantizer(init_scale);
775     quantizer.DoQuantization(static_cast<float *>(weight->data_c()), weight->shape_c(), 0, &quant_params, &quant_data);
776   } else {
777     MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
778   }
779   auto status =
780     UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), quant_data_type);
781   if (status != RET_OK) {
782     MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
783     return RET_ERROR;
784   }
785 
786   if (quant_params.empty()) {
787     MS_LOG(ERROR) << "quant_params empty";
788     return RET_ERROR;
789   }
790   auto quant_param_holder = GetCNodeQuantHolder(primitive);
791   quant_param_holder->set_input_quant_param(index, quant_params);
792   return ret;
793 }
CheckNodeInSet(const CNodePtr & cnode,const std::set<PrimitivePtr> & support_primitive_types)794 bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types) {
795   for (const auto &type : support_primitive_types) {
796     if (opt::CheckPrimitiveType(cnode, type)) {
797       return true;
798     }
799   }
800   return false;
801 }
802 }  // namespace mindspore::lite::quant
803