/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/train/train_export.h"
#include <unistd.h>
#include <sys/stat.h>
#include <fstream>
#include <utility>
#include <queue>
#include <algorithm>
#include <functional>
#include <numeric>
#include <map>
#include <set>
#include "schema/inner/model_generated.h"
#include "src/train/train_utils.h"
#include "src/common/quant_utils.h"
#include "src/common/storage.h"
#include "src/train/graph_fusion.h"
#include "src/train/graph_dropout.h"
#include "src/litert/weight_decoder.h"
#include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
#include "base/float16.h"

namespace mindspore {
namespace lite {
namespace {
constexpr static int kFmkVal = 3;
constexpr static int kTransformTensorDim = 4;
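// Returns the indices of all nodes in graphT that consume tensorIdx as an input.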
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> postNodeIdx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &oldNode = graphT.nodes.at(i);
    if (oldNode == nullptr) {
      continue;
    }
    auto inputIndexes = oldNode->inputIndex;
    if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
      postNodeIdx.emplace_back(i);
    }
  }
  return postNodeIdx;
}

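// Collects the indices of all nodes that consume any output tensor of `node`
// (or only the output at outputIndexIdx when it is not -1), de-duplicated via a set.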
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                     const int outputIndexIdx = -1) {
  std::vector<uint32_t> outputIndexes;
  if (outputIndexIdx == -1) {
    outputIndexes = node.outputIndex;
  } else {
    outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
  }
  std::set<size_t> outputNodeIdx;
  for (uint32_t outputIdx : outputIndexes) {
    auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
    outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
  }
  std::vector<size_t> ret;
  ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
  return ret;
}
}  // namespace

std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
  uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data());
  auto size = tensor->Size();
  std::vector<uint8_t> data(tensor_data, tensor_data + size);
  return data;
}

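// A tensor is quantized on export either when weight quantization was requested
// (QT_WEIGHT) and the tensor is multi-dimensional, or when the model was already
// weight-quantized (QT_DEFAULT with a QUANT_WEIGHT tensor quant type).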
bool TrainExport::NeedQuantization(const lite::Tensor *t, const int tensor_quant_type) {
  return ((quant_type_ == QT_WEIGHT && t->shape().size() > 1) ||
          ((quant_type_ == QT_DEFAULT) && (tensor_quant_type == schema::QuantType_QUANT_WEIGHT)));
}

schema::QuantType TrainExport::GetNodeQuantType(const mindspore::kernel::KernelExec *kernel) {
  return static_cast<schema::QuantType>(kernel->op_parameter()->quant_type_);
}

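// Marks a node as QUANT_WEIGHT when any of its constant inputs carries quant
// parameters; run as an extra pass after export so newly created nodes get tagged.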
void TrainExport::TagQuantizedNodes() {
  if (quant_type_ == QT_WEIGHT) {
    for (auto &node : meta_graph_->nodes) {
      if (node->quantType != schema::QuantType_QUANT_WEIGHT) {
        for (auto t_idx : node->inputIndex) {
          if ((meta_graph_->allTensors.at(t_idx)->nodeType == NodeType_ValueNode) &&
              (meta_graph_->allTensors.at(t_idx)->quantParams.size() > 0)) {
            node->quantType = schema::QuantType_QUANT_WEIGHT;
          }
        }
      }
    }
  }
}

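// Quantizes float32 tensor data to int8, per-layer when a single quant param is
// present (kPerTensor) and per-channel otherwise, then stores the quantized data
// and the computed quant params in dest_tensor.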
int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tensor *src_tensor, int preferred_dim) {
  int channels = 1;
  int bit_num = 8;

  if (src_tensor->quant_params().size() > 0) {
    channels = src_tensor->quant_params().size();
    bit_num = src_tensor->quant_params().at(0).bitNum;
  }
  if (channels < 1) {
    MS_LOG(ERROR) << "Quant params are empty";
    return RET_ERROR;
  }
  int quant_max = QuantMax(bit_num, false);
  int quant_min = QuantMin(bit_num, false);
  std::vector<int8_t> data(src_tensor->ElementsNum());
  std::vector<schema::QuantParamT> quant_params;

  STATUS ret = RET_OK;
  if (channels == kPerTensor) {
    ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                  &(quant_params), quant_max, quant_min, bit_num, &data, false, false);
  } else {
    ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                    &(quant_params), quant_max, quant_min, bit_num, &data, dest_tensor->dims,
                                    preferred_dim, true, false, false);
  }
  if (ret == RET_NO_CHANGE) {
    MS_LOG(DEBUG) << "No need to quant per channel";
    return RET_OK;
  }
  if (ret == RET_ERROR) {
    MS_LOG(ERROR) << "QuantTensorData error, channels = " << channels;
    return ret;
  }
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "quant_params empty";
    return RET_ERROR;
  }
  dest_tensor->data = std::vector<uint8_t>(data.data(), data.data() + data.size());
  dest_tensor->dataType = kNumberTypeInt8;
  dest_tensor->quantParams.clear();
  for (auto quant_param : quant_params) {
    dest_tensor->quantParams.emplace_back(std::make_unique<schema::QuantParamT>(quant_param));
  }

  return RET_OK;
}

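// Builds the exported TensorT. Tensors produced by constant folding are exported
// as value nodes; constant data is either copied verbatim or quantized via
// QuantTensorData when NeedQuantization() says so.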
std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(
  const mindspore::lite::Tensor *tensor, const std::vector<mindspore::lite::Tensor *> const_folded_output,
  schema::Tensor *scTensor, int preferred_dim, const int tensor_quant_type) {
  MS_CHECK_TRUE_RET(tensor != nullptr, nullptr);
  MS_CHECK_TRUE_RET(scTensor != nullptr, nullptr);
  auto tensorT = std::make_unique<schema::TensorT>();
  bool const_fold = false;
  if (quant_type_ == QT_NONE && !const_folded_output.empty() &&
      std::find(const_folded_output.begin(), const_folded_output.end(), tensor) != const_folded_output.end()) {
    tensorT->nodeType = NodeType_ValueNode;
    const_fold = true;
  } else {
    tensorT->nodeType = scTensor->nodeType();
  }
  tensorT->dims = tensor->shape();
  tensorT->format = static_cast<schema::Format>(tensor->format());
  tensorT->name = tensor->tensor_name();
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->dataType = tensor->data_type();
  tensorT->enableHuffmanCode = false;
  if (((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) ||
      const_fold) {
    if (NeedQuantization(tensor, tensor_quant_type)) {
      auto ret = QuantTensorData(tensorT.get(), tensor, preferred_dim);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "QuantTensorData failed.";
        return nullptr;
      }
    } else {
      tensorT->data = CreateData(tensor);
    }
  }
  tensorT->quantClusters = tensor->quant_clusters();
  return tensorT;
}

LiteGraph::Node *TrainExport::FindNode(const mindspore::kernel::KernelExec *kernel, const Model *model) {
  auto nodes = model->graph_.all_nodes_;
  auto it = std::find_if(nodes.begin(), nodes.end(),
                         [&kernel](mindspore::lite::LiteGraph::Node *n) { return (kernel->name() == n->name_); });
  if (it == nodes.end()) {
    return nullptr;
  }
  return *it;
}

int TrainExport::CreateAndAddCNode(const mindspore::kernel::KernelExec *kernel, std::vector<uint32_t> inputIndex,
                                   std::vector<uint32_t> outputIndex, const Model *model) {
  auto cnode = CreateCNode(kernel, inputIndex, outputIndex, model);
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "failed to create cnode";
    return RET_ERROR;
  }
  meta_graph_->nodes.emplace_back(std::move(cnode));
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
  }
  return RET_OK;
}

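// Creates the exported CNodeT for a kernel: the primitive is recovered by
// locating the matching node (by name) in the source model and unpacking its
// flatbuffer primitive.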
std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::KernelExec *kernel,
                                                         std::vector<uint32_t> inputIndex,
                                                         std::vector<uint32_t> outputIndex, const Model *model) {
  auto cnodeT = std::make_unique<schema::CNodeT>();
  if (cnodeT == nullptr) {
    MS_LOG(ERROR) << "cannot allocate node";
    return nullptr;
  }
  cnodeT->inputIndex = inputIndex;
  cnodeT->outputIndex = outputIndex;
  cnodeT->name = kernel->name();
  cnodeT->quantType = GetNodeQuantType(kernel);
  // find kernel in model
  auto *node = FindNode(kernel, model);
  if (node == nullptr) {
    MS_LOG(ERROR) << "cannot find kernel " + kernel->name() + " in model";
    return nullptr;
  }
  auto primitive = reinterpret_cast<schema::Primitive *>(const_cast<void *>(node->primitive_));
  cnodeT->primitive = std::unique_ptr<schema::PrimitiveT>(primitive->UnPack());
  return cnodeT;
}

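// Verifies and unpacks an existing model buffer into meta_graph_, clearing its
// output indices so a later export can rebuild them.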
int TrainExport::LoadModel(void *buf, size_t buf_size) {
  flatbuffers::Verifier verify((const uint8_t *)buf, buf_size);
  if (!schema::VerifyMetaGraphBuffer(verify)) {
    MS_LOG(ERROR) << "model flatbuffer verify fail";
    return RET_ERROR;
  }
  meta_graph_ = schema::GetMetaGraph(buf)->UnPack();
  meta_graph_->outputIndex.clear();
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->outputIndices.clear();
  }
  return RET_OK;
}

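// The next three helpers build the pieces of an NCHW -> NHWC transpose that
// AddTransformNode() splices in after selected tensors: the transposed output
// tensor, the constant perm tensor {0, 2, 3, 1}, and the Transpose node itself.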
std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) {
  auto &scTensor = meta_graph_->allTensors.at(id);
  auto tensorT = std::make_unique<schema::TensorT>();
  if (tensorT == nullptr) {
    MS_LOG(ERROR) << "Could not create tensor";
    return nullptr;
  }
  tensorT->nodeType = scTensor->nodeType;
  tensorT->dataType = scTensor->dataType;
  std::vector<int32_t> dims;
  std::vector<int32_t> val = {0, 2, 3, 1};
  if (scTensor->dims.size() == kTransformTensorDim) {
    for (size_t i = 0; i < val.size(); i++) {
      dims.push_back(scTensor->dims.at(val[i]));
    }
    tensorT->dims = dims;
  } else {
    tensorT->dims = scTensor->dims;
  }
  tensorT->format = schema::Format_NHWC;
  tensorT->name = scTensor->name + "_post";
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->enableHuffmanCode = false;
  return tensorT;
}

std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_id) {
  auto tensorT = std::make_unique<schema::TensorT>();
  if (tensorT == nullptr) {
    MS_LOG(ERROR) << "Could not create tensor";
    return nullptr;
  }
  tensorT->nodeType = lite::NodeType_ValueNode;
  tensorT->dataType = TypeId::kNumberTypeInt32;
  tensorT->dims = {kTransformTensorDim};
  tensorT->format = schema::Format_NCHW;
  tensorT->name = "const-" + std::to_string(last_id);
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->enableHuffmanCode = false;
  int32_t val[] = {0, 2, 3, 1};
  uint8_t *valp = reinterpret_cast<uint8_t *>(val);
  tensorT->data = std::vector<uint8_t>(valp, valp + sizeof(val));
  return tensorT;
}

std::unique_ptr<schema::CNodeT> TrainExport::CreateTransformNode(std::vector<uint32_t> inputIndex,
                                                                 std::vector<uint32_t> outputIndex, size_t id) {
  auto cnodeT = std::make_unique<schema::CNodeT>();
  if (cnodeT == nullptr) {
    MS_LOG(ERROR) << "cannot allocate node";
    return nullptr;
  }
  cnodeT->inputIndex = inputIndex;
  cnodeT->outputIndex = outputIndex;
  cnodeT->name = "transpose-" + std::to_string(id);
  cnodeT->quantType = schema::QuantType_QUANT_NONE;
  cnodeT->primitive = std::make_unique<schema::PrimitiveT>();
  cnodeT->primitive->value.type = schema::PrimitiveType_Transpose;
  return cnodeT;
}

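// For every tensor recorded in connect_, appends the perm constant, the NHWC
// output tensor, and a Transpose node to the graph, then remaps connect_ so
// downstream consumers read the transposed tensor.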
int TrainExport::AddTransformNode() {
  std::unordered_map<size_t, size_t> reconnect;
  size_t last_id = meta_graph_->allTensors.size();
  size_t last_node = meta_graph_->nodes.size();
  for (auto it : connect_) {
    auto tensorConst = CreateTransformConst(last_id);
    if (tensorConst == nullptr) {
      MS_LOG(ERROR) << "error in create tensor";
      return RET_ERROR;
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorConst));  // last_id
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
    auto tensorT = CreateTransformTensor(it.second);
    if (tensorT == nullptr) {
      MS_LOG(ERROR) << "error in create tensor";
      return RET_ERROR;
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorT));  // last_id + 1
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
    std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)};
    std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)};
    reconnect[it.first] = last_id + 1;
    auto cnode = CreateTransformNode(in_idx, out_idx, last_node);
    if (cnode == nullptr) {
      MS_LOG(ERROR) << "error in node creation";
      return RET_ERROR;
    }
    meta_graph_->nodes.emplace_back(std::move(cnode));
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
    }
  }
  connect_ = reconnect;
  return RET_OK;
}

void TrainExport::PrepareRemap(int offset) {
  for (auto it : connect_) {
    remap_[it.first + offset] = it.second;
  }
}

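// Searches the given tensor indices for a tensor whose name matches search_name;
// returns RET_NO_CHANGE (rather than an error) when no match exists.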
int TrainExport::FindSchemaTensorByName(const std::vector<uint32_t> &search_indices, const std::string &search_name,
                                        size_t *target_index) {
  MS_CHECK_TRUE_MSG(target_index != nullptr, RET_ERROR, "input param target_index is nullptr.");
  auto total_size = meta_graph_->allTensors.size();
  for (auto index : search_indices) {
    MS_CHECK_TRUE_MSG(index < total_size, RET_ERROR, "index is out of range.");
    if (meta_graph_->allTensors[index]->name == search_name) {
      *target_index = index;
      return RET_OK;
    }
  }
  return RET_NO_CHANGE;
}

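// Reorders meta_graph_->inputIndex to match the input order of the original
// model, matching tensors by name; original inputs absent from the exported
// graph are skipped.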
int TrainExport::KeepGraphInputsInOrder(const Model *model) {
  MS_CHECK_TRUE_MSG(model != nullptr, RET_ERROR, "input param model is nullptr.");
  MS_CHECK_TRUE_MSG(meta_graph_->inputIndex.size() <= model->graph_.input_indices_.size(), RET_ERROR,
                    "export model input indices size is larger than origin input indices size.");
  std::vector<uint32_t> origin_inputs_order;
  for (auto index : model->graph_.input_indices_) {
    MS_CHECK_TRUE_MSG(index < model->graph_.all_tensors_.size(), RET_ERROR, "input index out of range.");
    auto ori_input_tensor = model->graph_.all_tensors_[index];
    size_t meta_graph_input_index;
    auto status =
      FindSchemaTensorByName(meta_graph_->inputIndex, ori_input_tensor->name()->str(), &meta_graph_input_index);
    if (status == RET_NO_CHANGE) {
      MS_LOG(DEBUG) << "can't find tensor: " << ori_input_tensor->name()->str() << " in exported graph.";
      continue;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << "find schema tensor failed.";
      return RET_ERROR;
    }
    origin_inputs_order.emplace_back(meta_graph_input_index);
  }
  meta_graph_->inputIndex = origin_inputs_order;
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->inputIndices = origin_inputs_order;
  }
  return RET_OK;
}
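
// Exports every tensor recorded in map_index: creates the schema tensor (using
// the per-tensor preferred quant dim from WeightDecoder), registers graph
// inputs, and finally appends the graph outputs in the order given by
// output_names.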
int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset,
                              const std::vector<mindspore::lite::Tensor *> const_folded_output,
                              const std::vector<std::pair<size_t, tensor_info>> &map_index,
                              const std::vector<std::string> &output_names, const std::set<size_t> &out_set) {
  std::vector<mindspore::lite::Tensor *> in_tensors;
  for (auto index : map_index) {
    auto id = index.first;
    size_t pid = id - static_cast<size_t>(offset);
    mindspore::lite::Tensor *tensor = tensors.at(pid);
    in_tensors.push_back(tensor);
  }
  std::map<std::string, uint32_t> ordered_output_names;
  for (auto index : map_index) {
    auto id = index.first;
    size_t pid = id - static_cast<size_t>(offset);
    mindspore::lite::Tensor *tensor = tensors.at(pid);
    schema::Tensor *scTensor = model->graph_.all_tensors_.at(pid);
    auto preferred_dim = WeightDecoder::GetPreferredDim(in_tensors, index.second.op_parameter,
                                                        index.second.input_index, tensor->shape(),
                                                        model->graph_.version_);
    auto tensorT =
      CreateTensor(tensor, const_folded_output, scTensor, preferred_dim, index.second.op_parameter->quant_type_);
    if (tensorT == nullptr) {
      MS_LOG(ERROR) << "error in tensor creation";
      return RET_ERROR;
    }
    if (out_set.find(remap_[id]) == out_set.end()) {
      if (IsInputTensor(*tensorT)) {
        meta_graph_->inputIndex.push_back(remap_[id]);
        if (!meta_graph_->subGraph.empty()) {
          meta_graph_->subGraph[0]->inputIndices.push_back(remap_[id]);
        }
      }
    }
    // find output tensor
    if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) {
      ordered_output_names[tensor->tensor_name()] = remap_[id];
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorT));
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
  }
  for (auto &output_name : output_names) {
    if (ordered_output_names.find(output_name) != ordered_output_names.end()) {
      meta_graph_->outputIndex.push_back(ordered_output_names[output_name]);
      if (!meta_graph_->subGraph.empty()) {
        meta_graph_->subGraph[0]->outputIndices.push_back(ordered_output_names[output_name]);
      }
    }
  }
  return RET_OK;
}

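// Main export entry point: assigns contiguous indices to every kernel
// input/output tensor (via remap_), emits one CNodeT per kernel, exports the
// tensors, restores the original graph-input order, tags weight-quantized
// nodes, and topologically sorts the result.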
int TrainExport::ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
                           const std::vector<mindspore::lite::Tensor *> &tensors,
                           const std::vector<mindspore::lite::Tensor *> const_folded_output,
                           const std::vector<std::string> &output_names, const Model *model,
                           QuantizationType quant_type, const Model *bb_model) {
  std::vector<std::pair<size_t, tensor_info>> map_index;
  std::set<size_t> out_set;
  if (meta_graph_ == nullptr) {
    int status = ExportInit(model->graph_.name_, model->graph_.version_);
    if (status != RET_OK) {
      return status;
    }
  }
  int offset = meta_graph_->allTensors.size();
  int tensor_idx = offset;
  quant_type_ = quant_type;
  PrepareRemap(offset);

  for (const auto kernel : kernels) {
    std::vector<uint32_t> in_idx, out_idx;
    size_t input_index = 0;
    for (const auto tensor : kernel->in_tensors()) {
      size_t id = TSFindTensor(tensors, tensor) + static_cast<size_t>(offset);
      if (id == tensors.size()) {
        MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
        return RET_ERROR;
      }
      auto it = remap_.find(id);
      if (it == remap_.end()) {
        remap_[id] = tensor_idx;
        in_idx.push_back(tensor_idx);
        map_index.push_back({id, {input_index++, kernel->op_parameter()}});
        tensor_idx++;
      } else {
        in_idx.push_back(it->second);
      }
    }
    size_t output_index = 0;
    for (const auto tensor : kernel->out_tensors()) {
      size_t id = TSFindTensor(tensors, tensor) + offset;
      if (id == tensors.size()) {
        MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
        return RET_ERROR;
      }
      auto it = remap_.find(id);
      if (it == remap_.end()) {
        remap_[id] = tensor_idx;
        map_index.push_back({id, {output_index++, kernel->op_parameter()}});
        out_idx.push_back(tensor_idx);
        out_set.insert(tensor_idx);
        tensor_idx++;
      } else {
        out_idx.push_back(it->second);
        out_set.insert(it->second);
      }
    }
    auto ret = CreateAndAddCNode(kernel, in_idx, out_idx, model);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "failed to create cnode";
      return ret;
    }
  }

  auto status = ExportTensor(model, tensors, offset, const_folded_output, map_index, output_names, out_set);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "ExportTensor failed.";
    return RET_ERROR;
  }
  auto origin_input_model = bb_model == nullptr ? model : bb_model;
  status = KeepGraphInputsInOrder(origin_input_model);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "keep graph inputs in order failed.";
    return RET_ERROR;
  }
  TagQuantizedNodes();  // do another loop to mark QUANT_WEIGHT_NODES
  status = TopologicalSort();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TopologicalSort failed.";
    return RET_ERROR;
  }

  return RET_OK;
}

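// Kahn-style topological sort per subgraph: a node is "sinked" once all of its
// inputs are graph inputs, constants, or outputs of already-sinked nodes; nodes
// are re-emitted in that order and the subgraph's nodeIndices are rewritten.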
int TrainExport::TopologicalSort() {
  MS_ASSERT(meta_graph_ != nullptr);
  std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
  std::vector<size_t> sinked_tensor_idxes;
  for (auto &subgraph : meta_graph_->subGraph) {
    std::copy(subgraph->inputIndices.begin(), subgraph->inputIndices.end(), std::back_inserter(sinked_tensor_idxes));
  }
  // put all const tensor indices into sinked_tensor_idxes
  for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
    if (meta_graph_->allTensors.at(i)->nodeType == NodeType_ValueNode) {
      sinked_tensor_idxes.push_back(i);
    }
  }
  auto &old_nodes = meta_graph_->nodes;
  std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
  // seed the queue with all nodes that have no unsatisfied dependencies
  for (size_t i = 0; i < meta_graph_->subGraph.size(); i++) {
    std::vector<unsigned int> new_subgraph_node_indices = {};
    auto subgraph_node_indices = meta_graph_->subGraph[i]->nodeIndices;

    for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
      auto &node = old_nodes[subgraph_node_indices[j]];
      if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
        sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
        op_queue.push(std::move(node));
      }
    }
    while (!op_queue.empty()) {
      auto &node = op_queue.front();
      auto post_node_idxes = GetOutputNodeIdx(*meta_graph_, *(node.get()));
      sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
      for (auto post_node_idx : post_node_idxes) {
        if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
          auto &post_node = old_nodes.at(post_node_idx);
          // check whether post_node has become non-dependent
          if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
            op_queue.push(std::move(post_node));
          }
        }
      }
      new_nodes.emplace_back(std::move(node));
      new_subgraph_node_indices.push_back(new_nodes.size() - 1);
      op_queue.pop();
    }
    meta_graph_->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
  }
  if (new_nodes.size() != old_nodes.size()) {
    MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
                  << ", new_nodes size: " << new_nodes.size();
    return RET_ERROR;
  }
  meta_graph_->nodes.swap(new_nodes);
  return RET_OK;
}

bool TrainExport::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
                                  const std::vector<size_t> &sinked_tensor_idxes) {
  MS_ASSERT(node != nullptr);
  return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                     [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
}

int TrainExport::ExportInit(const std::string model_name, std::string version) {
  meta_graph_ = new (std::nothrow) schema::MetaGraphT();
  if (meta_graph_ == nullptr) {
    MS_LOG(ERROR) << "cannot allocate meta_graph";
    return RET_ERROR;
  }
  auto sub_graph = std::make_unique<schema::SubGraphT>();
  if (sub_graph == nullptr) {
    MS_LOG(ERROR) << "cannot allocate SubGraphT";
    return RET_ERROR;
  }
  sub_graph->name = model_name + "_subgraph";
  meta_graph_->subGraph.emplace_back(std::move(sub_graph));
  meta_graph_->fmkType = kFmkVal;
  meta_graph_->name = model_name;
  meta_graph_->version = version;
  return RET_OK;
}

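// Ensures the output file carries the ".ms" suffix and is writable before
// delegating to Model::Export.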
int TrainExport::SaveModel(lite::Model *model, const std::string &file_name) {
  std::string filename = file_name;
  if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
    filename = filename + ".ms";
  }
#ifndef _MSC_VER
  if (access(filename.c_str(), F_OK) == 0) {
    chmod(filename.c_str(), S_IWUSR);
  }
#endif
  int status = mindspore::lite::Model::Export(model, filename.c_str());
  return status;
}

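// Serializes the model into a caller-provided Buffer sized from the LiteModel's
// internal buffer; a size mismatch after Export indicates the resize failed.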
int TrainExport::SaveModel(lite::Model *model, Buffer *model_buffer) {
  MS_CHECK_FALSE_MSG(model == nullptr, RET_ERROR, "model cannot be empty.");
  MS_CHECK_FALSE_MSG(model_buffer == nullptr, RET_ERROR, "model_buffer cannot be empty.");
  auto *liteModel = reinterpret_cast<LiteModel *>(model);
  auto size = liteModel->buf_size_;
  model_buffer->ResizeData(size);

  size_t out_size = model_buffer->DataSize();
  int status = mindspore::lite::Model::Export(model, static_cast<char *>(model_buffer->MutableData()), &out_size);
  if (out_size != size) {
    MS_LOG(ERROR) << "model_buffer resize failed.";
    return RET_ERROR;
  }

  return status;
}

int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }

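// Packs meta_graph_ into a flatbuffer and copies the finished buffer into
// model_buffer_.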
int TrainExport::SaveToBuffer() {
  constexpr size_t kFbBuilderInitSize = 1024;
  flatbuffers::FlatBufferBuilder builder(kFbBuilderInitSize);
  auto offset = schema::MetaGraph::Pack(builder, meta_graph_);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  size_t size = builder.GetSize();
  auto content = builder.GetBufferPointer();
  MS_CHECK_FALSE_MSG(content == nullptr, RET_ERROR, "content cannot be empty.");
  MS_CHECK_FALSE_MSG(model_buffer_ == nullptr, RET_ERROR, "model_buffer_ cannot be empty.");
  model_buffer_->SetData(content, size);
  return RET_OK;
}
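
// Writes all constant tensor data to file_name_ in graph order. Tensors listed
// in changeable_weights_name are prefixed with their shape; float32 data is
// converted to float16 when enable_fp16 is set.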
int TrainExport::SaveWeightsToFile(bool enable_fp16, const std::vector<std::string> &changeable_weights_name) {
  const auto &all_tensors = meta_graph_->allTensors;
  std::ofstream weights(file_name_, std::ios::out | std::ios::trunc | std::ios::binary);
  if (!weights.is_open()) {
    MS_LOG(ERROR) << "Cannot open weight file: " << file_name_;
    return RET_ERROR;
  }
  for (auto &tensor : all_tensors) {
    MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "Existing tensor is a nullptr.");
    if (tensor->data.empty()) {
      continue;
    }
    if (std::find(changeable_weights_name.begin(), changeable_weights_name.end(), tensor->name) !=
        changeable_weights_name.end()) {
      auto shape = tensor->dims;
      weights.write(reinterpret_cast<const char *>(shape.data()), shape.size() * sizeof(uint32_t));
      if (weights.fail()) {
        MS_LOG(ERROR) << "Write weights failed, weight file: " << file_name_;
        weights.close();
        return RET_ERROR;
      }
    }
    if (!enable_fp16 || tensor->dataType != kNumberTypeFloat32) {
      weights.write(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size());
      if (weights.fail()) {
        MS_LOG(ERROR) << "Write weights failed, weight file: " << file_name_;
        weights.close();
        return RET_ERROR;
      }
    } else {
      std::vector<uint16_t> data_fp16(tensor->data.size() / sizeof(float));
#ifndef ENABLE_ARM
      auto fp32_data = reinterpret_cast<const float *>(tensor->data.data());
      auto fp16_data = reinterpret_cast<float16 *>(data_fp16.data());
      CHECK_NULL_RETURN(fp32_data);
      CHECK_NULL_RETURN(fp16_data);
      for (size_t j = 0; j < data_fp16.size(); ++j) {
        fp16_data[j] = float16(fp32_data[j]);
      }
#else
      Float32ToFloat16_fp16_handler(tensor->data.data(), data_fp16.data(), data_fp16.size(), true);
#endif
      weights.write(reinterpret_cast<const char *>(data_fp16.data()), data_fp16.size() * sizeof(uint16_t));
      if (weights.fail()) {
        MS_LOG(ERROR) << "Write weights failed, weight file: " << file_name_;
        weights.close();
        return RET_ERROR;
      }
    }
  }
  weights.close();
#ifndef _MSC_VER
  chmod(file_name_.c_str(), S_IRUSR | S_IWUSR);
#endif
  return RET_OK;
}

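// A tensor counts as a graph input when it has no constant data but a
// non-empty, fully-known shape (product of dims != 0).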
bool TrainExport::IsInputTensor(const schema::TensorT &t) {
  int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
  return ((t.data.size() == 0) && (total_dims != 0));
}

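// TrainModelFusion applies the training graph-fusion passes to meta_graph_;
// TrainModelDrop runs the GraphDropout pass, which presumably strips Dropout
// nodes when exporting an inference graph.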
int TrainExport::TrainModelFusion() {
  GraphFusion graph_fusion;
  auto status = graph_fusion.Run(meta_graph_);
  if (status != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

int TrainExport::TrainModelDrop() {
  GraphDropout graph_dropout;
  auto status = graph_dropout.Run(meta_graph_);
  if (status != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

TrainExport::~TrainExport() { delete meta_graph_; }
}  // namespace lite
}  // namespace mindspore