/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/train/train_export.h"
#include <unistd.h>
#include <sys/stat.h>
#include <fstream>
#include <utility>
#include <queue>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <numeric>
#include <unordered_map>
#include "schema/inner/model_generated.h"
#include "src/train/train_utils.h"
#include "src/common/quant_utils.h"
#include "src/common/storage.h"
#include "src/train/graph_fusion.h"
#include "src/train/graph_dropout.h"
#include "src/litert/weight_decoder.h"
#include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
#include "base/float16.h"

namespace mindspore {
namespace lite {
namespace {
constexpr static int kFmkVal = 3;
constexpr static int kTransformTensorDim = 4;
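// Returns the indices of all nodes in graphT that consume the tensor at tensorIdx as an input.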
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> postNodeIdx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &oldNode = graphT.nodes.at(i);
    if (oldNode == nullptr) {
      continue;
    }
    auto inputIndexes = oldNode->inputIndex;
    if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
      postNodeIdx.emplace_back(i);
    }
  }
  return postNodeIdx;
}

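// Collects the indices of every node fed by this node's outputs. With the default
// outputIndexIdx of -1 all outputs are considered; otherwise only the output at that
// position. The intermediate std::set removes duplicates when several outputs reach
// the same consumer.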
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                     const int outputIndexIdx = -1) {
  std::vector<uint32_t> outputIndexes;
  if (outputIndexIdx == -1) {
    outputIndexes = node.outputIndex;
  } else {
    outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
  }
  std::set<size_t> outputNodeIdx;
  for (uint32_t outputIdx : outputIndexes) {
    auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
    outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
  }
  std::vector<size_t> ret;
  ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
  return ret;
}
}  // namespace

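// Copies a runtime tensor's raw bytes into a vector suitable for schema::TensorT::data.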
std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
  uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data());
  auto size = tensor->Size();
  std::vector<uint8_t> data(tensor_data, tensor_data + size);
  return data;
}

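// A tensor is quantized on export when weight quantization is requested and the tensor
// is multi-dimensional, or when exporting with QT_DEFAULT and the tensor is already
// weight-quantized in the source model.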
bool TrainExport::NeedQuantization(const lite::Tensor *t, const int tensor_quant_type) {
  return ((quant_type_ == QT_WEIGHT && t->shape().size() > 1) ||
          ((quant_type_ == QT_DEFAULT) && (tensor_quant_type == schema::QuantType_QUANT_WEIGHT)));
}

schema::QuantType TrainExport::GetNodeQuantType(const mindspore::kernel::KernelExec *kernel) {
  return static_cast<schema::QuantType>(kernel->op_parameter()->quant_type_);
}

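// In weight-quantization mode, tags every node whose constant inputs carry quant params
// as QUANT_WEIGHT so consumers of the exported graph decode those weights correctly.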
void TrainExport::TagQuantizedNodes() {
  if (quant_type_ == QT_WEIGHT) {
    for (auto &node : meta_graph_->nodes) {
      if (node->quantType != schema::QuantType_QUANT_WEIGHT) {
        for (auto t_idx : node->inputIndex) {
          if ((meta_graph_->allTensors.at(t_idx)->nodeType == NodeType_ValueNode) &&
              (meta_graph_->allTensors.at(t_idx)->quantParams.size() > 0)) {
            node->quantType = schema::QuantType_QUANT_WEIGHT;
          }
        }
      }
    }
  }
}

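// Quantizes a float tensor to int8 for export: per-layer when the channel count equals
// kPerTensor, per-channel otherwise. Stores the quantized bytes and the computed quant
// params in dest_tensor.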
int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tensor *src_tensor, int preferred_dim) {
  int channels = 1;
  int bit_num = 8;

  if (src_tensor->quant_params().size() > 0) {
    channels = src_tensor->quant_params().size();
    bit_num = src_tensor->quant_params().at(0).bitNum;
  }
  if (channels < 1) {
    MS_LOG(ERROR) << "Quant Params is empty";
    return RET_ERROR;
  }
  int quant_max = QuantMax(bit_num, false);
  int quant_min = QuantMin(bit_num, false);
  std::vector<int8_t> data(src_tensor->ElementsNum());
  std::vector<schema::QuantParamT> quant_params;

  STATUS ret = RET_OK;
  if (channels == kPerTensor) {
    ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                  &(quant_params), quant_max, quant_min, bit_num, &data, false, false);
  } else {
    ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                    &(quant_params), quant_max, quant_min, bit_num, &data, dest_tensor->dims,
                                    preferred_dim, true, false, false);
  }
  if (ret == RET_NO_CHANGE) {
    MS_LOG(DEBUG) << "No need to quant per channel";
    return RET_OK;
  }
  if (ret == RET_ERROR) {
    MS_LOG(ERROR) << "QuantTensorData error, channels = " << channels;
    return ret;
  }
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "quant_params empty";
    return RET_ERROR;
  }
  dest_tensor->data = std::vector<uint8_t>(data.data(), data.data() + data.size());
  dest_tensor->dataType = kNumberTypeInt8;
  dest_tensor->quantParams.clear();
  for (auto quant_param : quant_params) {
    dest_tensor->quantParams.emplace_back(std::make_unique<schema::QuantParamT>(quant_param));
  }

  return RET_OK;
}

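// Builds a schema::TensorT from a runtime tensor. Outputs of constant folding are
// re-tagged as value nodes; constant data is copied verbatim or quantized, depending on
// NeedQuantization.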
std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(
  const mindspore::lite::Tensor *tensor, const std::vector<mindspore::lite::Tensor *> const_folded_output,
  schema::Tensor *scTensor, int preferred_dim, const int tensor_quant_type) {
  MS_CHECK_TRUE_RET(tensor != nullptr, nullptr);
  MS_CHECK_TRUE_RET(scTensor != nullptr, nullptr);
  auto tensorT = std::make_unique<schema::TensorT>();
  bool const_fold = false;
  if (quant_type_ == QT_NONE && !const_folded_output.empty() &&
      std::find(const_folded_output.begin(), const_folded_output.end(), tensor) != const_folded_output.end()) {
    tensorT->nodeType = NodeType_ValueNode;
    const_fold = true;
  } else {
    tensorT->nodeType = scTensor->nodeType();
  }
  tensorT->dims = tensor->shape();
  tensorT->format = static_cast<schema::Format>(tensor->format());
  tensorT->name = tensor->tensor_name();
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->dataType = tensor->data_type();
  tensorT->enableHuffmanCode = false;
  if (((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) ||
      const_fold) {
    if (NeedQuantization(tensor, tensor_quant_type)) {
      auto ret = QuantTensorData(tensorT.get(), tensor, preferred_dim);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "QuantTensorData failed.";
        return nullptr;
      }
    } else {
      tensorT->data = CreateData(tensor);
    }
  }
  tensorT->quantClusters = tensor->quant_clusters();
  return tensorT;
}

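// Locates the LiteGraph node matching a kernel by name.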
LiteGraph::Node *TrainExport::FindNode(const mindspore::kernel::KernelExec *kernel, const Model *model) {
  auto nodes = model->graph_.all_nodes_;
  auto it = std::find_if(nodes.begin(), nodes.end(),
                         [&kernel](mindspore::lite::LiteGraph::Node *n) { return (kernel->name() == n->name_); });
  if (it == nodes.end()) {
    return nullptr;
  }
  return *it;
}

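// Wraps CreateCNode and appends the result to the meta graph and to the first
// subgraph's node list, if one exists.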
int TrainExport::CreateAndAddCNode(const mindspore::kernel::KernelExec *kernel, std::vector<uint32_t> inputIndex,
                                   std::vector<uint32_t> outputIndex, const Model *model) {
  auto cnode = CreateCNode(kernel, inputIndex, outputIndex, model);
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "failed to create cnode";
    return RET_ERROR;
  }
  meta_graph_->nodes.emplace_back(std::move(cnode));
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
  }
  return RET_OK;
}

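// Translates a runtime kernel into a schema::CNodeT, unpacking the primitive of the
// corresponding node in the original model.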
std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::KernelExec *kernel,
                                                         std::vector<uint32_t> inputIndex,
                                                         std::vector<uint32_t> outputIndex, const Model *model) {
  auto cnodeT = std::make_unique<schema::CNodeT>();
  if (cnodeT == nullptr) {
    MS_LOG(ERROR) << "cannot allocate node";
    return nullptr;
  }
  cnodeT->inputIndex = inputIndex;
  cnodeT->outputIndex = outputIndex;
  cnodeT->name = kernel->name();
  cnodeT->quantType = GetNodeQuantType(kernel);
  // find kernel in model
  auto *node = FindNode(kernel, model);
  if (node == nullptr) {
    MS_LOG(ERROR) << "cannot find kernel " + kernel->name() + " in model";
    return nullptr;
  }
  auto primitive = reinterpret_cast<schema::Primitive *>(const_cast<void *>(node->primitive_));
  cnodeT->primitive = std::unique_ptr<schema::PrimitiveT>(primitive->UnPack());
  return cnodeT;
}

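// Verifies and unpacks an existing flatbuffer model into meta_graph_, clearing output
// indices so the export can rebuild them.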
int TrainExport::LoadModel(void *buf, size_t buf_size) {
  flatbuffers::Verifier verify(reinterpret_cast<const uint8_t *>(buf), buf_size);
  if (!schema::VerifyMetaGraphBuffer(verify)) {
    MS_LOG(ERROR) << "model flatbuffer verify fail";
    return RET_ERROR;
  }
  meta_graph_ = schema::GetMetaGraph(buf)->UnPack();
  meta_graph_->outputIndex.clear();
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->outputIndices.clear();
  }
  return RET_OK;
}

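// Builds the NHWC output tensor of an inserted transpose; a 4D shape is permuted with
// the NCHW->NHWC order {0, 2, 3, 1}.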
std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) {
  auto &scTensor = meta_graph_->allTensors.at(id);
  auto tensorT = std::make_unique<schema::TensorT>();
  if (tensorT == nullptr) {
    MS_LOG(ERROR) << "Could not create tensor";
    return nullptr;
  }
  tensorT->nodeType = scTensor->nodeType;
  tensorT->dataType = scTensor->dataType;
  std::vector<int32_t> dims;
  std::vector<int32_t> val = {0, 2, 3, 1};
  if (scTensor->dims.size() == kTransformTensorDim) {
    for (size_t i = 0; i < val.size(); i++) {
      dims.push_back(scTensor->dims.at(val[i]));
    }
    tensorT->dims = dims;
  } else {
    tensorT->dims = scTensor->dims;
  }
  tensorT->format = schema::Format_NHWC;
  tensorT->name = scTensor->name + "_post";
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->enableHuffmanCode = false;
  return tensorT;
}

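// Builds the constant perm tensor {0, 2, 3, 1} consumed by an inserted transpose node.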
std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_id) {
  auto tensorT = std::make_unique<schema::TensorT>();
  if (tensorT == nullptr) {
    MS_LOG(ERROR) << "Could not create tensor";
    return nullptr;
  }
  tensorT->nodeType = lite::NodeType_ValueNode;
  tensorT->dataType = TypeId::kNumberTypeInt32;
  tensorT->dims = {kTransformTensorDim};
  tensorT->format = schema::Format_NCHW;
  tensorT->name = "const-" + std::to_string(last_id);
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->enableHuffmanCode = false;
  int32_t val[] = {0, 2, 3, 1};
  uint8_t *valp = reinterpret_cast<uint8_t *>(val);
  tensorT->data = std::vector<uint8_t>(valp, valp + sizeof(val));
  return tensorT;
}

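// Builds a Transpose node wiring the given input and output tensor indices.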
std::unique_ptr<schema::CNodeT> TrainExport::CreateTransformNode(std::vector<uint32_t> inputIndex,
                                                                 std::vector<uint32_t> outputIndex, size_t id) {
  auto cnodeT = std::make_unique<schema::CNodeT>();
  if (cnodeT == nullptr) {
    MS_LOG(ERROR) << "cannot allocate node";
    return nullptr;
  }
  cnodeT->inputIndex = inputIndex;
  cnodeT->outputIndex = outputIndex;
  cnodeT->name = "transpose-" + std::to_string(id);
  cnodeT->quantType = schema::QuantType_QUANT_NONE;
  cnodeT->primitive = std::make_unique<schema::PrimitiveT>();
  cnodeT->primitive->value.type = schema::PrimitiveType_Transpose;
  return cnodeT;
}

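// For each recorded connection, appends a perm constant, an NHWC output tensor and a
// transpose node, then re-points the connection map at the new output tensors.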
int TrainExport::AddTransformNode() {
  std::unordered_map<size_t, size_t> reconnect;
  for (auto it : connect_) {
    // recompute indices each iteration: every connection appends two tensors and one node
    size_t last_id = meta_graph_->allTensors.size();
    size_t last_node = meta_graph_->nodes.size();
    auto tensorConst = CreateTransformConst(last_id);
    if (tensorConst == nullptr) {
      MS_LOG(ERROR) << "error in create tensor";
      return RET_ERROR;
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorConst));  // last_id
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
    auto tensorT = CreateTransformTensor(it.second);
    if (tensorT == nullptr) {
      MS_LOG(ERROR) << "error in create tensor";
      return RET_ERROR;
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorT));  // last_id + 1
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
    std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)};
    std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)};
    reconnect[it.first] = last_id + 1;
    auto cnode = CreateTransformNode(in_idx, out_idx, last_node);
    if (cnode == nullptr) {
      MS_LOG(ERROR) << "error in node creation";
      return RET_ERROR;
    }
    meta_graph_->nodes.emplace_back(std::move(cnode));
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
    }
  }
  connect_ = reconnect;
  return RET_OK;
}

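// Seeds the tensor remap table from the connection map, shifting keys by the tensor
// index offset.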
void TrainExport::PrepareRemap(int offset) {
  for (auto it : connect_) {
    remap_[it.first + offset] = it.second;
  }
}

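// Looks up a tensor by name among the given indices; returns RET_NO_CHANGE if absent.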
int TrainExport::FindSchemaTensorByName(const std::vector<uint32_t> &search_indices, const std::string &search_name,
                                        size_t *target_index) {
  MS_CHECK_TRUE_MSG(target_index != nullptr, RET_ERROR, "input param target_index is nullptr.");
  auto total_size = meta_graph_->allTensors.size();
  for (auto index : search_indices) {
    MS_CHECK_TRUE_MSG(index < total_size, RET_ERROR, "index is out of range.");
    if (meta_graph_->allTensors[index]->name == search_name) {
      *target_index = index;
      return RET_OK;
    }
  }
  return RET_NO_CHANGE;
}

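// Restores the exported graph's input order to match the original model, matching input
// tensors by name.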
int TrainExport::KeepGraphInputsInOrder(const Model *model) {
  MS_CHECK_TRUE_MSG(model != nullptr, RET_ERROR, "input param model is nullptr.");
  MS_CHECK_TRUE_MSG(meta_graph_->inputIndex.size() <= model->graph_.input_indices_.size(), RET_ERROR,
                    "export model input indices size is larger than origin input indices size.");
  std::vector<uint32_t> origin_inputs_order;
  for (auto index : model->graph_.input_indices_) {
    MS_CHECK_TRUE_MSG(index < model->graph_.all_tensors_.size(), RET_ERROR, "input index out of range.");
    auto ori_input_tensor = model->graph_.all_tensors_[index];
    size_t meta_graph_input_index;
    auto status =
      FindSchemaTensorByName(meta_graph_->inputIndex, ori_input_tensor->name()->str(), &meta_graph_input_index);
    if (status == RET_NO_CHANGE) {
      MS_LOG(DEBUG) << "can't find tensor: " << ori_input_tensor->name()->str() << " in exported graph.";
      continue;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << "find schema tensor failed.";
      return RET_ERROR;
    }
    origin_inputs_order.emplace_back(meta_graph_input_index);
  }
  meta_graph_->inputIndex = origin_inputs_order;
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->inputIndices = origin_inputs_order;
  }
  return RET_OK;
}
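
// Converts the collected runtime tensors into schema tensors, registering graph inputs
// (tensors with no data that are not produced by any node) and graph outputs in the
// order requested by output_names.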
int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset,
                              const std::vector<mindspore::lite::Tensor *> const_folded_output,
                              const std::vector<std::pair<size_t, tensor_info>> &map_index,
                              const std::vector<std::string> &output_names, const std::set<size_t> &out_set) {
  std::vector<mindspore::lite::Tensor *> in_tensors;
  for (auto index : map_index) {
    auto id = index.first;
    size_t pid = id - static_cast<size_t>(offset);
    mindspore::lite::Tensor *tensor = tensors.at(pid);
    in_tensors.push_back(tensor);
  }
  std::map<std::string, uint32_t> ordered_output_names;
  for (auto index : map_index) {
    auto id = index.first;
    size_t pid = id - static_cast<size_t>(offset);
    mindspore::lite::Tensor *tensor = tensors.at(pid);
    schema::Tensor *scTensor = model->graph_.all_tensors_.at(pid);
    auto preferred_dim = WeightDecoder::GetPreferredDim(in_tensors, index.second.op_parameter, index.second.input_index,
                                                        tensor->shape(), model->graph_.version_);
    auto tensorT =
      CreateTensor(tensor, const_folded_output, scTensor, preferred_dim, index.second.op_parameter->quant_type_);
    if (tensorT == nullptr) {
      MS_LOG(ERROR) << "error in tensor creation";
      return RET_ERROR;
    }
    if (out_set.find(remap_[id]) == out_set.end()) {
      if (IsInputTensor(*tensorT)) {
        meta_graph_->inputIndex.push_back(remap_[id]);
        if (!meta_graph_->subGraph.empty()) {
          meta_graph_->subGraph[0]->inputIndices.push_back(remap_[id]);
        }
      }
    }
    // find output tensor
    if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) {
      ordered_output_names[tensor->tensor_name()] = remap_[id];
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorT));
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
  }
  for (auto &output_name : output_names) {
    if (ordered_output_names.find(output_name) != ordered_output_names.end()) {
      meta_graph_->outputIndex.push_back(ordered_output_names[output_name]);
      if (!meta_graph_->subGraph.empty()) {
        meta_graph_->subGraph[0]->outputIndices.push_back(ordered_output_names[output_name]);
      }
    }
  }
  return RET_OK;
}

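// Main export entry point: remaps kernel in/out tensors to meta graph indices, emits a
// CNode per kernel, exports the tensors, restores the original input order, tags
// weight-quantized nodes and topologically sorts the result.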
int TrainExport::ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
                           const std::vector<mindspore::lite::Tensor *> &tensors,
                           const std::vector<mindspore::lite::Tensor *> const_folded_output,
                           const std::vector<std::string> &output_names, const Model *model,
                           QuantizationType quant_type, const Model *bb_model) {
  std::vector<std::pair<size_t, tensor_info>> map_index;
  std::set<size_t> out_set;
  if (meta_graph_ == nullptr) {
    int status = ExportInit(model->graph_.name_, model->graph_.version_);
    if (status != RET_OK) {
      return status;
    }
  }
  int offset = meta_graph_->allTensors.size();
  int tensor_idx = offset;
  quant_type_ = quant_type;
  PrepareRemap(offset);

  for (const auto kernel : kernels) {
    std::vector<uint32_t> in_idx, out_idx;
    size_t input_index = 0;
    for (const auto tensor : kernel->in_tensors()) {
      size_t id = TSFindTensor(tensors, tensor);
      if (id == tensors.size()) {  // check before applying the offset, otherwise "not found" is missed
        MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
        return RET_ERROR;
      }
      id += static_cast<size_t>(offset);
      auto it = remap_.find(id);
      if (it == remap_.end()) {
        remap_[id] = tensor_idx;
        in_idx.push_back(tensor_idx);
        map_index.push_back({id, {input_index++, kernel->op_parameter()}});
        tensor_idx++;
      } else {
        in_idx.push_back(it->second);
      }
    }
    size_t output_index = 0;
    for (const auto tensor : kernel->out_tensors()) {
      size_t id = TSFindTensor(tensors, tensor);
      if (id == tensors.size()) {
        MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
        return RET_ERROR;
      }
      id += static_cast<size_t>(offset);
      auto it = remap_.find(id);
      if (it == remap_.end()) {
        remap_[id] = tensor_idx;
        map_index.push_back({id, {output_index++, kernel->op_parameter()}});
        out_idx.push_back(tensor_idx);
        out_set.insert(tensor_idx);
        tensor_idx++;
      } else {
        out_idx.push_back(it->second);
        out_set.insert(it->second);
      }
    }
    auto ret = CreateAndAddCNode(kernel, in_idx, out_idx, model);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "failed to create cnode";
      return ret;
    }
  }

  auto status = ExportTensor(model, tensors, offset, const_folded_output, map_index, output_names, out_set);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "ExportTensor failed.";
    return RET_ERROR;
  }
  auto origin_input_model = bb_model == nullptr ? model : bb_model;
  status = KeepGraphInputsInOrder(origin_input_model);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "keep graph inputs in order failed.";
    return RET_ERROR;
  }
  TagQuantizedNodes();  // do another loop to mark QUANT_WEIGHT_NODES
  status = TopologicalSort();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TopologicalSort failed.";
    return RET_ERROR;
  }

  return RET_OK;
}

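// Queue-based (Kahn-style) topological sort per subgraph: graph inputs and constants
// are sinked up front, and a node is scheduled once all tensors it reads have been
// produced.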
int TrainExport::TopologicalSort() {
  MS_ASSERT(meta_graph_ != nullptr);
  std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
  std::vector<size_t> sinked_tensor_idxes;
  for (auto &subgraph : meta_graph_->subGraph) {
    std::copy(subgraph->inputIndices.begin(), subgraph->inputIndices.end(), std::back_inserter(sinked_tensor_idxes));
  }
  // put all const tensor indices into sinked_tensor_idxes
  for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
    if (meta_graph_->allTensors.at(i)->nodeType == NodeType_ValueNode) {
      sinked_tensor_idxes.push_back(i);
    }
  }
  auto &old_nodes = meta_graph_->nodes;
  std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
  // put all non-depending nodes into the queue
  for (size_t i = 0; i < meta_graph_->subGraph.size(); i++) {
    std::vector<unsigned int> new_subgraph_node_indices = {};
    auto subgraph_node_indices = meta_graph_->subGraph[i]->nodeIndices;

    for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
      auto &node = old_nodes[subgraph_node_indices[j]];
      if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
        sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
        op_queue.push(std::move(node));
      }
    }
    while (!op_queue.empty()) {
      auto &node = op_queue.front();
      auto post_node_idxes = GetOutputNodeIdx(*meta_graph_, *(node.get()));
      sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
      for (auto post_node_idx : post_node_idxes) {
        if (IsContain(subgraph_node_indices, static_cast<unsigned int>(post_node_idx))) {
          auto &post_node = old_nodes.at(post_node_idx);
          if (post_node == nullptr) {
            continue;  // already moved into the queue
          }
          // check if post_node is non-depended
          if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
            op_queue.push(std::move(post_node));
          }
        }
      }
      new_nodes.emplace_back(std::move(node));
      new_subgraph_node_indices.push_back(new_nodes.size() - 1);
      op_queue.pop();
    }
    meta_graph_->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
  }
  if (new_nodes.size() != old_nodes.size()) {
    MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
                  << ", new_nodes size: " << new_nodes.size();
    return RET_ERROR;
  }
  meta_graph_->nodes.swap(new_nodes);
  return RET_OK;
}

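// A node has no pending dependencies once every tensor it reads is already sinked.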
bool TrainExport::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
                                  const std::vector<size_t> &sinked_tensor_idxes) {
  MS_ASSERT(node != nullptr);
  return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                     [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
}

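// Allocates a fresh meta graph holding a single subgraph and records the model name and
// version.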
int TrainExport::ExportInit(const std::string model_name, std::string version) {
  meta_graph_ = new (std::nothrow) schema::MetaGraphT();
  if (meta_graph_ == nullptr) {
    MS_LOG(ERROR) << "cannot allocate meta_graph";
    return RET_ERROR;
  }
  auto sub_graph = std::make_unique<schema::SubGraphT>();
  if (sub_graph == nullptr) {
    MS_LOG(ERROR) << "cannot allocate SubGraphT";
    return RET_ERROR;
  }
  sub_graph->name = model_name + "_subgraph";
  meta_graph_->subGraph.emplace_back(std::move(sub_graph));
  meta_graph_->fmkType = kFmkVal;
  meta_graph_->name = model_name;
  meta_graph_->version = version;
  return RET_OK;
}

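// Serializes a model to <file_name>.ms, making a pre-existing file writable first.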
int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
  std::string filename = file_name;
  if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
    filename = filename + ".ms";
  }
#ifndef _MSC_VER
  if (access(filename.c_str(), F_OK) == 0) {
    chmod(filename.c_str(), S_IWUSR);
  }
#endif
  int status = mindspore::lite::Model::Export(model, filename.c_str());
  return status;
}

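// Serializes a model into a caller-provided buffer sized to the lite model's buffer.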
int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
  MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be empty.");
  MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be empty.");
  auto *liteModel = reinterpret_cast<LiteModel *>(model);
  auto size = liteModel->buf_size_;
  model_buffer->ResizeData(size);

  size_t out_size = model_buffer->DataSize();
  int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
  if (out_size != size) {
    MS_LOG(ERROR) << "model_buffer resize failed.";
    return RET_ERROR;
  }

  return status;
}

int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }

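// Packs meta_graph_ into a flatbuffer and stores the serialized bytes in model_buffer_.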
int TrainExport::SaveToBuffer() {
  constexpr size_t kFbBuilderInitSize = 1024;
  flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
  auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  size_t size = builder.GetSize();
  auto content = builder.GetBufferPointer();
  MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "content cannot be empty.");
  MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "model_buffer_ cannot be empty.");
  model_buffer_->SetData(content, size);
  return RET_OK;
}
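
// Writes all constant tensor data to a standalone weight file. Tensors listed in
// changeable_weights_name are prefixed with their shape; float32 data is optionally
// converted to float16 before writing.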
int TrainExport::SaveWeightsToFile(bool enable_fp16, const std::vector<std::string> &changeable_weights_name) {
  const auto &all_tensors = meta_graph_->allTensors;
  std::ofstream weights(file_name_, std::ios::out | std::ios::trunc | std::ios::binary);
  if (!weights.is_open()) {
    MS_LOG(ERROR) << "Cannot open weight file: " << file_name_;
    return RET_ERROR;
  }
  for (auto &tensor : all_tensors) {
    MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "Exist tensor is a nullptr.");
    if (tensor->data.empty()) {
      continue;
    }
    if (std::find(changeable_weights_name.begin(), changeable_weights_name.end(), tensor->name) !=
        changeable_weights_name.end()) {
      auto shape = tensor->dims;
      weights.write(reinterpret_cast<const char *>(shape.data()), shape.size() * sizeof(uint32_t));
      if (weights.fail()) {
        MS_LOG(ERROR) << "Write weights failed, weight file: " << file_name_;
        weights.close();
        return RET_ERROR;
      }
    }
    if (!enable_fp16 || tensor->dataType != kNumberTypeFloat32) {
      weights.write(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size());
      if (weights.fail()) {
        MS_LOG(ERROR) << "Write weights failed, weight file: " << file_name_;
        weights.close();
        return RET_ERROR;
      }
    } else {
      std::vector<uint16_t> data_fp16(tensor->data.size() / sizeof(float));
#ifndef ENABLE_ARM
      auto fp32_data = reinterpret_cast<const float *>(tensor->data.data());
      auto fp16_data = reinterpret_cast<float16 *>(data_fp16.data());
      CHECK_NULL_RETURN(fp32_data);
      CHECK_NULL_RETURN(fp16_data);
      for (size_t j = 0; j < data_fp16.size(); ++j) {
        fp16_data[j] = float16(fp32_data[j]);
      }
#else
      Float32ToFloat16_fp16_handler(tensor->data.data(), data_fp16.data(), data_fp16.size(), true);
#endif
      weights.write(reinterpret_cast<const char *>(data_fp16.data()), data_fp16.size() * sizeof(uint16_t));
      if (weights.fail()) {
        MS_LOG(ERROR) << "Write weights failed, weight file: " << file_name_;
        weights.close();
        return RET_ERROR;
      }
    }
  }
  weights.close();
#ifndef _MSC_VER
  chmod(file_name_.c_str(), S_IRUSR | S_IWUSR);
#endif
  return RET_OK;
}

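// A graph input carries no data but has a non-zero element count.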
bool TrainExport::IsInputTensor(const schema::TensorT &t) {
  int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
  return ((t.data.size() == 0) && (total_dims != 0));
}

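// Runs the export-time graph fusion pass over the meta graph.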
int TrainExport::TrainModelFusion() {
  GraphFusion graph_fusion;
  auto status = graph_fusion.Run(meta_graph_);
  if (status != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

int TrainExport::TrainModelDrop() {
  GraphDropout graph_dropout;
  auto status = graph_dropout.Run(meta_graph_);
  if (status != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

TrainExport::~TrainExport() { delete meta_graph_; }
}  // namespace lite
}  // namespace mindspore