1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/train/train_export.h"
17 #include <unistd.h>
18 #include <sys/stat.h>
19 #include <fstream>
20 #include <utility>
21 #include <queue>
22 #include <algorithm>
23 #include <functional>
24 #include <map>
25 #include <set>
26 #include "schema/inner/model_generated.h"
27 #include "src/train/train_utils.h"
28 #include "src/common/quant_utils.h"
29 #include "tools/common/storage.h"
30
31 namespace mindspore {
32 namespace lite {
33 namespace {
34 constexpr static int kFmkVal = 3;
35 constexpr static int kTransformTensorDim = 4;
GetLinkedPostIdx(const schema::MetaGraphT & graphT,const size_t & tensor_idx)36 std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensor_idx) {
37 std::vector<size_t> post_node_idx;
38 for (size_t i = 0; i < graphT.nodes.size(); i++) {
39 auto &old_node = graphT.nodes.at(i);
40 if (old_node == nullptr) {
41 continue;
42 }
43 auto input_indexes = old_node->inputIndex;
44 if (IsContain<uint32_t>(input_indexes, tensor_idx)) {
45 post_node_idx.emplace_back(i);
46 }
47 }
48 return post_node_idx;
49 }
50
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const schema::CNodeT & node,const int output_index_idx=-1)51 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
52 const int output_index_idx = -1) {
53 std::vector<uint32_t> output_indexes;
54 if (output_index_idx == -1) {
55 output_indexes = node.outputIndex;
56 } else {
57 output_indexes.emplace_back(node.outputIndex.at(output_index_idx));
58 }
59 std::set<size_t> output_node_idx;
60 for (uint32_t outputIdx : output_indexes) {
61 auto linked_post_idx = GetLinkedPostIdx(graphT, outputIdx);
62 output_node_idx.insert(linked_post_idx.begin(), linked_post_idx.end());
63 }
64 std::vector<size_t> ret;
65 ret.insert(ret.end(), output_node_idx.begin(), output_node_idx.end());
66 return ret;
67 }
68 } // namespace
69
CreateData(const lite::Tensor * tensor)70 std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
71 uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data());
72 auto size = tensor->Size();
73 std::vector<uint8_t> data(tensor_data, tensor_data + size);
74 return data;
75 }
76
NeedQuantization(const lite::Tensor * t)77 bool TrainExport::NeedQuantization(const lite::Tensor *t) {
78 return ((quant_type_ == QT_WEIGHT && t->shape().size() > 1) ||
79 ((quant_type_ == QT_DEFAULT) && (t->quant_params().size() > 0) && (t->quant_params().at(0).inited)));
80 }
81
GetNodeQuantType(const kernel::LiteKernel * kernel)82 schema::QuantType TrainExport::GetNodeQuantType(const kernel::LiteKernel *kernel) {
83 if (std::any_of(kernel->in_tensors().cbegin(), kernel->in_tensors().cend(), [](const lite::Tensor *t) {
84 return (t->IsConst() && (t->quant_params().size() > 0) && (t->quant_params().at(0).inited));
85 })) {
86 return schema::QuantType_QUANT_WEIGHT;
87 }
88 return schema::QuantType_QUANT_NONE;
89 }
90
TagQuantizedNodes()91 void TrainExport::TagQuantizedNodes() {
92 if (quant_type_ == QT_WEIGHT) {
93 for (auto &node : meta_graph_->nodes) {
94 if (node->quantType != schema::QuantType_QUANT_WEIGHT) {
95 for (auto t_idx : node->inputIndex) {
96 if ((meta_graph_->allTensors.at(t_idx)->nodeType == NodeType_ValueNode) &&
97 (meta_graph_->allTensors.at(t_idx)->quantParams.size() > 0)) {
98 node->quantType = schema::QuantType_QUANT_WEIGHT;
99 }
100 }
101 }
102 }
103 }
104 }
105
QuantTensorData(schema::TensorT * dest_tensor,const lite::Tensor * src_tensor)106 int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tensor *src_tensor) {
107 int channels = 1;
108 int bit_num = 8;
109
110 if (src_tensor->quant_params().size() > 0) {
111 channels = src_tensor->quant_params().size();
112 bit_num = src_tensor->quant_params().at(0).bitNum;
113 }
114 if (channels < 1) {
115 MS_LOG(ERROR) << "Quant Params is empty";
116 return RET_ERROR;
117 }
118 int quant_max = QuantMax(bit_num, kNumberTypeInt8);
119 int quant_min = QuantMin(bit_num, kNumberTypeInt8);
120 std::vector<int8_t> data(src_tensor->ElementsNum());
121 std::vector<schema::QuantParamT> quant_params;
122
123 STATUS ret = RET_OK;
124 if (channels == kPerTensor) {
125 ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
126 &(quant_params), quant_max, quant_min, bit_num, false, &data);
127 } else {
128 bool channel_at_first = (src_tensor->shape().at(0) == channels);
129 ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
130 schema::QuantType_WeightQuant, &(quant_params), quant_max, quant_min, bit_num,
131 false, &data, channels, channel_at_first);
132 }
133 if (ret == RET_QUANT_CONTINUE) {
134 MS_LOG(DEBUG) << "No Need to quant per channel";
135 return RET_OK;
136 }
137 if (ret == RET_ERROR) {
138 MS_LOG(ERROR) << "QuantTensorData error, channels = " << channels;
139 return ret;
140 }
141 if (quant_params.empty()) {
142 MS_LOG(ERROR) << "quant_params empty";
143 return RET_ERROR;
144 }
145 dest_tensor->data = std::vector<uint8_t>(data.data(), data.data() + data.size());
146 dest_tensor->dataType = kNumberTypeInt8;
147 dest_tensor->quantParams.clear();
148 for (auto quant_param : quant_params) {
149 dest_tensor->quantParams.emplace_back(std::make_unique<schema::QuantParamT>(quant_param));
150 }
151
152 return RET_OK;
153 }
154
CreateTensor(const mindspore::lite::Tensor * tensor,schema::Tensor * scTensor)155 std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite::Tensor *tensor,
156 schema::Tensor *scTensor) {
157 auto tensorT = std::make_unique<schema::TensorT>();
158 tensorT->nodeType = scTensor->nodeType();
159 tensorT->dims = tensor->shape();
160 tensorT->format = static_cast<schema::Format>(tensor->format());
161 tensorT->name = tensor->tensor_name();
162 tensorT->refCount = 0;
163 tensorT->offset = 0;
164 tensorT->dataType = tensor->data_type();
165 tensorT->enableHuffmanCode = false;
166 if ((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) {
167 if (NeedQuantization(tensor)) {
168 QuantTensorData(tensorT.get(), tensor);
169 } else {
170 tensorT->data = CreateData(tensor);
171 }
172 }
173 tensorT->quantClusters = tensor->quant_clusters();
174 return tensorT;
175 }
176
FindNode(const mindspore::kernel::LiteKernel * kernel,const Model * model)177 Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model) {
178 auto nodes = model->all_nodes_;
179 auto it = std::find_if(nodes.begin(), nodes.end(),
180 [&kernel](mindspore::lite::Model::Node *n) { return (kernel->name() == n->name_); });
181 if (it == nodes.end()) {
182 return nullptr;
183 }
184 return *it;
185 }
186
CreateAndAddCNode(const mindspore::kernel::LiteKernel * kernel,std::vector<uint32_t> inputIndex,std::vector<uint32_t> outputIndex,const Model * model)187 int TrainExport::CreateAndAddCNode(const mindspore::kernel::LiteKernel *kernel, std::vector<uint32_t> inputIndex,
188 std::vector<uint32_t> outputIndex, const Model *model) {
189 auto cnode = CreateCNode(kernel, inputIndex, outputIndex, model);
190 if (cnode == nullptr) {
191 MS_LOG(ERROR) << "failed to create cnode";
192 return RET_ERROR;
193 }
194 meta_graph_->nodes.emplace_back(std::move(cnode));
195 if (!meta_graph_->subGraph.empty()) {
196 meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
197 }
198 return RET_OK;
199 }
200
CreateCNode(const mindspore::kernel::LiteKernel * kernel,std::vector<uint32_t> inputIndex,std::vector<uint32_t> outputIndex,const Model * model)201 std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::LiteKernel *kernel,
202 std::vector<uint32_t> inputIndex,
203 std::vector<uint32_t> outputIndex, const Model *model) {
204 auto cnodeT = std::make_unique<schema::CNodeT>();
205 if (cnodeT == nullptr) {
206 MS_LOG(ERROR) << " cannot allocate node";
207 return nullptr;
208 }
209 cnodeT->inputIndex = inputIndex;
210 cnodeT->outputIndex = outputIndex;
211 cnodeT->name = kernel->name();
212 cnodeT->quantType = GetNodeQuantType(kernel);
213 // find kernel in model
214 auto *node = FindNode(kernel, model);
215 if (node == nullptr) {
216 MS_LOG(ERROR) << "cannot find kernel " + kernel->name() + " in model";
217 return nullptr;
218 }
219 auto primitive = reinterpret_cast<schema::Primitive *>(const_cast<void *>(node->primitive_));
220 cnodeT->primitive = std::unique_ptr<schema::PrimitiveT>(primitive->UnPack());
221 return cnodeT;
222 }
223
LoadModel(void * buf,size_t buf_size)224 int TrainExport::LoadModel(void *buf, size_t buf_size) {
225 flatbuffers::Verifier verify((const uint8_t *)buf, buf_size);
226 if (!schema::VerifyMetaGraphBuffer(verify)) {
227 MS_LOG(ERROR) << "model flatbuffer verify fail";
228 return RET_ERROR;
229 }
230 meta_graph_ = schema::GetMetaGraph(buf)->UnPack();
231 meta_graph_->outputIndex.clear();
232 if (!meta_graph_->subGraph.empty()) {
233 meta_graph_->subGraph[0]->outputIndices.clear();
234 }
235 return RET_OK;
236 }
237
CreateTransformTensor(size_t id)238 std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) {
239 auto &scTensor = meta_graph_->allTensors.at(id);
240 auto tensorT = std::make_unique<schema::TensorT>();
241 if (tensorT == nullptr) {
242 MS_LOG(ERROR) << "Could not create tensor ";
243 return nullptr;
244 }
245 tensorT->nodeType = scTensor->nodeType;
246 tensorT->dataType = scTensor->dataType;
247 std::vector<int32_t> dims;
248 std::vector<int32_t> val = {0, 2, 3, 1};
249 if (scTensor->dims.size() == kTransformTensorDim) {
250 for (size_t i = 0; i < val.size(); i++) {
251 dims.push_back(scTensor->dims.at(val[i]));
252 }
253 tensorT->dims = dims;
254 } else {
255 tensorT->dims = scTensor->dims;
256 }
257 tensorT->format = schema::Format_NHWC;
258 tensorT->name = scTensor->name + "_post";
259 tensorT->refCount = 0;
260 tensorT->offset = 0;
261 tensorT->enableHuffmanCode = false;
262 return tensorT;
263 }
264
CreateTransformConst(size_t last_id)265 std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_id) {
266 auto tensorT = std::make_unique<schema::TensorT>();
267 if (tensorT == nullptr) {
268 MS_LOG(ERROR) << "Could not create tensor ";
269 return nullptr;
270 }
271 tensorT->nodeType = lite::NodeType_ValueNode;
272 tensorT->dataType = TypeId::kNumberTypeInt32;
273 tensorT->dims = {kTransformTensorDim};
274 tensorT->format = schema::Format_NCHW;
275 tensorT->name = "const-" + std::to_string(last_id);
276 tensorT->refCount = 0;
277 tensorT->offset = 0;
278 tensorT->enableHuffmanCode = false;
279 int32_t val[] = {0, 2, 3, 1};
280 uint8_t *valp = reinterpret_cast<uint8_t *>(val);
281 tensorT->data = std::vector<uint8_t>(valp, valp + sizeof(val));
282 return tensorT;
283 }
284
CreateTransformNode(std::vector<uint32_t> inputIndex,std::vector<uint32_t> outputIndex,size_t id)285 std::unique_ptr<schema::CNodeT> TrainExport::CreateTransformNode(std::vector<uint32_t> inputIndex,
286 std::vector<uint32_t> outputIndex, size_t id) {
287 auto cnodeT = std::make_unique<schema::CNodeT>();
288 if (cnodeT == nullptr) {
289 MS_LOG(ERROR) << "cannot allocate node";
290 return nullptr;
291 }
292 cnodeT->inputIndex = inputIndex;
293 cnodeT->outputIndex = outputIndex;
294 cnodeT->name = "transpose-" + std::to_string(id);
295 cnodeT->quantType = schema::QuantType_QUANT_NONE;
296 cnodeT->primitive = std::make_unique<schema::PrimitiveT>();
297 cnodeT->primitive->value.type = schema::PrimitiveType_Transpose;
298 return cnodeT;
299 }
300
AddTransformNode()301 int TrainExport::AddTransformNode() {
302 std::unordered_map<size_t, size_t> reconnect;
303 size_t last_id = meta_graph_->allTensors.size();
304 size_t last_node = meta_graph_->nodes.size();
305 for (auto it : connect_) {
306 auto tensorConst = CreateTransformConst(last_id);
307 if (tensorConst == nullptr) {
308 MS_LOG(ERROR) << "error in create tensor";
309 return RET_ERROR;
310 }
311 meta_graph_->allTensors.emplace_back(std::move(tensorConst)); // last_id
312 if (!meta_graph_->subGraph.empty()) {
313 meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
314 }
315 auto tensorT = CreateTransformTensor(it.second);
316 if (tensorT == nullptr) {
317 MS_LOG(ERROR) << "error in create tensor";
318 return RET_ERROR;
319 }
320 meta_graph_->allTensors.emplace_back(std::move(tensorT)); // last_id + 1
321 if (!meta_graph_->subGraph.empty()) {
322 meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
323 }
324 std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)};
325 std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)};
326 reconnect[it.first] = last_id + 1;
327 auto cnode = CreateTransformNode(in_idx, out_idx, last_node);
328 if (cnode == nullptr) {
329 MS_LOG(ERROR) << "error in node creation";
330 return RET_ERROR;
331 }
332 meta_graph_->nodes.emplace_back(std::move(cnode));
333 if (!meta_graph_->subGraph.empty()) {
334 meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
335 }
336 }
337 connect_ = reconnect;
338 return RET_OK;
339 }
340
PrepareRemap(int offset)341 void TrainExport::PrepareRemap(int offset) {
342 for (auto it : connect_) {
343 remap_[it.first + offset] = it.second;
344 }
345 }
346
ExportNet(const std::vector<mindspore::kernel::LiteKernel * > & kernels,const std::vector<mindspore::lite::Tensor * > & tensors,const std::vector<std::string> & output_names,const Model * model,QuantizationType quant_type)347 int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
348 const std::vector<mindspore::lite::Tensor *> &tensors,
349 const std::vector<std::string> &output_names, const Model *model,
350 QuantizationType quant_type) {
351 std::vector<size_t> map_index;
352 std::set<size_t> out_set;
353 int offset = meta_graph_->allTensors.size();
354 int tensor_idx = offset;
355 quant_type_ = quant_type;
356 if (meta_graph_ == nullptr) {
357 int status = ExportInit(model->name_, model->version_);
358 if (status != RET_OK) {
359 return status;
360 }
361 }
362 PrepareRemap(offset);
363
364 for (const auto kernel : kernels) {
365 std::vector<uint32_t> in_idx, out_idx;
366 for (const auto tensor : kernel->in_tensors()) {
367 size_t id = TSFindTensor(tensors, tensor) + offset;
368 if (id == tensors.size()) {
369 MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
370 return RET_ERROR;
371 }
372 auto it = remap_.find(id);
373 if (it == remap_.end()) {
374 remap_[id] = tensor_idx;
375 in_idx.push_back(tensor_idx);
376 map_index.push_back(id);
377 tensor_idx++;
378 } else {
379 in_idx.push_back(it->second);
380 }
381 }
382 for (const auto tensor : kernel->out_tensors()) {
383 size_t id = TSFindTensor(tensors, tensor) + offset;
384 if (id == tensors.size()) {
385 MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
386 return RET_ERROR;
387 }
388 auto it = remap_.find(id);
389 if (it == remap_.end()) {
390 remap_[id] = tensor_idx;
391 map_index.push_back(id);
392 out_idx.push_back(tensor_idx);
393 out_set.insert(tensor_idx);
394 tensor_idx++;
395 } else {
396 out_idx.push_back(it->second);
397 out_set.insert(it->second);
398 }
399 }
400 auto ret = CreateAndAddCNode(kernel, in_idx, out_idx, model);
401 if (ret != RET_OK) {
402 MS_LOG(ERROR) << "failed to create cnode";
403 return ret;
404 }
405 }
406 for (auto id : map_index) {
407 size_t pid = id - offset;
408 mindspore::lite::Tensor *tensor = tensors.at(pid);
409 schema::Tensor *scTensor = model->all_tensors_.at(pid);
410 auto tensorT = CreateTensor(tensor, scTensor);
411 if (tensorT == nullptr) {
412 MS_LOG(ERROR) << "error in tensor creation";
413 return RET_ERROR;
414 }
415 if (out_set.find(remap_[id]) == out_set.end()) {
416 if (IsInputTensor(*tensorT)) {
417 meta_graph_->inputIndex.push_back(remap_[id]);
418 if (!meta_graph_->subGraph.empty()) {
419 meta_graph_->subGraph[0]->inputIndices.push_back(remap_[id]);
420 }
421 }
422 }
423 // find output tensor
424 if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) {
425 meta_graph_->outputIndex.push_back(remap_[id]);
426 if (!meta_graph_->subGraph.empty()) {
427 meta_graph_->subGraph[0]->outputIndices.push_back(remap_[id]);
428 }
429 }
430 meta_graph_->allTensors.emplace_back(std::move(tensorT));
431 if (!meta_graph_->subGraph.empty()) {
432 meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
433 }
434 }
435 TagQuantizedNodes(); // do another loop to mark QUANT_WEIGHT_NODES
436 auto status = TopologicalSort();
437 if (status != RET_OK) {
438 MS_LOG(ERROR) << "TopologicalSort failed.";
439 return RET_ERROR;
440 }
441
442 return RET_OK;
443 }
444
TopologicalSort()445 int TrainExport::TopologicalSort() {
446 MS_ASSERT(meta_graph_ != nullptr);
447 std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
448 std::vector<size_t> sinked_tensor_idxes;
449 for (auto &subgraph : meta_graph_->subGraph) {
450 std::copy(subgraph->inputIndices.begin(), subgraph->inputIndices.end(), std::back_inserter(sinked_tensor_idxes));
451 }
452 // put all const tensor index into sinked_tensor_idxes
453 for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
454 if (meta_graph_->allTensors.at(i)->nodeType == NodeType_ValueNode) {
455 sinked_tensor_idxes.push_back(i);
456 }
457 }
458 auto &old_nodes = meta_graph_->nodes;
459 std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
460 // put all none depend node into queue
461 for (size_t i = 0; i < meta_graph_->subGraph.size(); i++) {
462 std::vector<unsigned int> new_subgraph_node_indices = {};
463 auto subgraph_node_indices = meta_graph_->subGraph[i]->nodeIndices;
464
465 for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
466 auto &node = old_nodes[subgraph_node_indices[j]];
467 if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
468 sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
469 op_queue.push(std::move(node));
470 }
471 }
472 while (!op_queue.empty()) {
473 auto &node = op_queue.front();
474 auto post_node_idxes = GetOutputNodeIdx(*meta_graph_, *(node.get()));
475 sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
476 for (auto post_node_idx : post_node_idxes) {
477 if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
478 auto &post_node = old_nodes.at(post_node_idx);
479 // check if post_node is non-depended
480 if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
481 op_queue.push(std::move(post_node));
482 }
483 }
484 }
485 new_nodes.emplace_back(std::move(node));
486 new_subgraph_node_indices.push_back(new_nodes.size() - 1);
487 op_queue.pop();
488 }
489 meta_graph_->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
490 }
491 if (new_nodes.size() != old_nodes.size()) {
492 MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
493 << ", new_nodes size: " << new_nodes.size();
494 return RET_ERROR;
495 }
496 meta_graph_->nodes.swap(new_nodes);
497 return RET_OK;
498 }
499
IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> & node,const std::vector<size_t> & sinked_tensor_idxes)500 bool TrainExport::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
501 const std::vector<size_t> &sinked_tensor_idxes) {
502 MS_ASSERT(node != nullptr);
503 return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
504 [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
505 }
506
ExportInit(const std::string model_name,std::string version)507 int TrainExport::ExportInit(const std::string model_name, std::string version) {
508 meta_graph_ = new (std::nothrow) schema::MetaGraphT();
509 if (meta_graph_ == nullptr) {
510 MS_LOG(ERROR) << "cannot allocate meta_graph";
511 return RET_ERROR;
512 }
513 auto sub_graph = std::make_unique<schema::SubGraphT>();
514 if (sub_graph == nullptr) {
515 MS_LOG(ERROR) << "cannot allocate SubGraphT";
516 return RET_ERROR;
517 }
518 sub_graph->name = model_name + "_subgraph";
519 meta_graph_->subGraph.emplace_back(std::move(sub_graph));
520 meta_graph_->fmkType = kFmkVal;
521 meta_graph_->name = model_name;
522 meta_graph_->version = version;
523 return RET_OK;
524 }
525
SaveToFile()526 int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
527
IsInputTensor(const schema::TensorT & t)528 int TrainExport::IsInputTensor(const schema::TensorT &t) {
529 int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
530 return ((t.data.size() == 0) && (total_dims != 0));
531 }
532
~TrainExport()533 TrainExport::~TrainExport() { delete meta_graph_; }
534 } // namespace lite
535 } // namespace mindspore
536