/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/train/train_export.h"
#include <unistd.h>
#include <sys/stat.h>
#include <fstream>
#include <utility>
#include <queue>
#include <algorithm>
#include <functional>
#include <map>
#include <numeric>
#include <set>
#include <unordered_map>
#include "schema/inner/model_generated.h"
#include "src/train/train_utils.h"
#include "src/common/quant_utils.h"
#include "tools/common/storage.h"

namespace mindspore {
namespace lite {
namespace {
constexpr static int kFmkVal = 3;
constexpr static int kTransformTensorDim = 4;
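// Returns the indices of all nodes in graphT that consume the tensor at tensor_idx as an input.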
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensor_idx) {
  std::vector<size_t> post_node_idx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &old_node = graphT.nodes.at(i);
    if (old_node == nullptr) {
      continue;
    }
    auto input_indexes = old_node->inputIndex;
    if (IsContain<uint32_t>(input_indexes, tensor_idx)) {
      post_node_idx.emplace_back(i);
    }
  }
  return post_node_idx;
}

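// Collects the indices of the nodes fed by the outputs of node. With output_index_idx == -1 all
// outputs are considered; otherwise only the output at that position.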
std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
                                     const int output_index_idx = -1) {
  std::vector<uint32_t> output_indexes;
  if (output_index_idx == -1) {
    output_indexes = node.outputIndex;
  } else {
    output_indexes.emplace_back(node.outputIndex.at(output_index_idx));
  }
  std::set<size_t> output_node_idx;
  for (uint32_t outputIdx : output_indexes) {
    auto linked_post_idx = GetLinkedPostIdx(graphT, outputIdx);
    output_node_idx.insert(linked_post_idx.begin(), linked_post_idx.end());
  }
  std::vector<size_t> ret;
  ret.insert(ret.end(), output_node_idx.begin(), output_node_idx.end());
  return ret;
}
}  // namespace

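// Copies the raw bytes of a tensor into a vector suitable for schema::TensorT::data.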
std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
  uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data());
  auto size = tensor->Size();
  std::vector<uint8_t> data(tensor_data, tensor_data + size);
  return data;
}

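// A tensor is quantized on export either in weight-quantization mode (any tensor with more than
// one dimension) or in full-quantization mode (any tensor with initialized quant params).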
bool TrainExport::NeedQuantization(const lite::Tensor *t) {
  return ((quant_type_ == QT_WEIGHT && t->shape().size() > 1) ||
          ((quant_type_ == QT_DEFAULT) && (t->quant_params().size() > 0) && (t->quant_params().at(0).inited)));
}

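// A node is exported as weight-quantized if any of its constant inputs carries initialized quant params.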
schema::QuantType TrainExport::GetNodeQuantType(const kernel::LiteKernel *kernel) {
  if (std::any_of(kernel->in_tensors().cbegin(), kernel->in_tensors().cend(), [](const lite::Tensor *t) {
        return (t->IsConst() && (t->quant_params().size() > 0) && (t->quant_params().at(0).inited));
      })) {
    return schema::QuantType_QUANT_WEIGHT;
  }
  return schema::QuantType_QUANT_NONE;
}

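// Extra pass over the exported nodes: tag every node that consumes a quantized constant tensor
// as QUANT_WEIGHT so the runtime decodes its weights correctly.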
void TrainExport::TagQuantizedNodes() {
  if (quant_type_ == QT_WEIGHT) {
    for (auto &node : meta_graph_->nodes) {
      if (node->quantType != schema::QuantType_QUANT_WEIGHT) {
        for (auto t_idx : node->inputIndex) {
          if ((meta_graph_->allTensors.at(t_idx)->nodeType == NodeType_ValueNode) &&
              (meta_graph_->allTensors.at(t_idx)->quantParams.size() > 0)) {
            node->quantType = schema::QuantType_QUANT_WEIGHT;
          }
        }
      }
    }
  }
}

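// Quantizes the float data of src_tensor to int8, per layer or per channel depending on the
// number of quant params, and stores the quantized bytes and params in dest_tensor.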
int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tensor *src_tensor) {
  int channels = 1;
  int bit_num = 8;

  if (src_tensor->quant_params().size() > 0) {
    channels = src_tensor->quant_params().size();
    bit_num = src_tensor->quant_params().at(0).bitNum;
  }
  if (channels < 1) {
    MS_LOG(ERROR) << "Quant Params is empty";
    return RET_ERROR;
  }
  int quant_max = QuantMax(bit_num, kNumberTypeInt8);
  int quant_min = QuantMin(bit_num, kNumberTypeInt8);
  std::vector<int8_t> data(src_tensor->ElementsNum());
  std::vector<schema::QuantParamT> quant_params;

  STATUS ret = RET_OK;
  if (channels == kPerTensor) {
    ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                  &(quant_params), quant_max, quant_min, bit_num, false, &data);
  } else {
    bool channel_at_first = (src_tensor->shape().at(0) == channels);
    ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                    schema::QuantType_WeightQuant, &(quant_params), quant_max, quant_min, bit_num,
                                    false, &data, channels, channel_at_first);
  }
  if (ret == RET_QUANT_CONTINUE) {
    MS_LOG(DEBUG) << "No need to quant per channel";
    return RET_OK;
  }
  if (ret == RET_ERROR) {
    MS_LOG(ERROR) << "QuantTensorData error, channels = " << channels;
    return ret;
  }
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "quant_params empty";
    return RET_ERROR;
  }
  dest_tensor->data = std::vector<uint8_t>(data.data(), data.data() + data.size());
  dest_tensor->dataType = kNumberTypeInt8;
  dest_tensor->quantParams.clear();
  for (auto quant_param : quant_params) {
    dest_tensor->quantParams.emplace_back(std::make_unique<schema::QuantParamT>(quant_param));
  }

  return RET_OK;
}

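// Builds the serialized TensorT for an exported tensor, quantizing constant data on the fly
// when NeedQuantization says so.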
std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite::Tensor *tensor,
                                                           schema::Tensor *scTensor) {
  auto tensorT = std::make_unique<schema::TensorT>();
  tensorT->nodeType = scTensor->nodeType();
  tensorT->dims = tensor->shape();
  tensorT->format = static_cast<schema::Format>(tensor->format());
  tensorT->name = tensor->tensor_name();
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->dataType = tensor->data_type();
  tensorT->enableHuffmanCode = false;
  if ((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) {
    if (NeedQuantization(tensor)) {
      if (QuantTensorData(tensorT.get(), tensor) != RET_OK) {
        MS_LOG(ERROR) << "failed to quantize tensor " << tensor->tensor_name();
        return nullptr;
      }
    } else {
      tensorT->data = CreateData(tensor);
    }
  }
  tensorT->quantClusters = tensor->quant_clusters();
  return tensorT;
}

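// Finds the Model::Node that matches a kernel by name, or nullptr if there is none.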
Model::Node *TrainExport::FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model) {
  auto nodes = model->all_nodes_;
  auto it = std::find_if(nodes.begin(), nodes.end(),
                         [&kernel](mindspore::lite::Model::Node *n) { return (kernel->name() == n->name_); });
  if (it == nodes.end()) {
    return nullptr;
  }
  return *it;
}

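// Creates a CNodeT for the kernel and appends it to the meta graph (and to the first subgraph, if any).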
int TrainExport::CreateAndAddCNode(const mindspore::kernel::LiteKernel *kernel, std::vector<uint32_t> inputIndex,
                                   std::vector<uint32_t> outputIndex, const Model *model) {
  auto cnode = CreateCNode(kernel, inputIndex, outputIndex, model);
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "failed to create cnode";
    return RET_ERROR;
  }
  meta_graph_->nodes.emplace_back(std::move(cnode));
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
  }
  return RET_OK;
}

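// Builds a CNodeT from a kernel, unpacking the primitive of the matching node in the source model.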
std::unique_ptr<schema::CNodeT> TrainExport::CreateCNode(const mindspore::kernel::LiteKernel *kernel,
                                                         std::vector<uint32_t> inputIndex,
                                                         std::vector<uint32_t> outputIndex, const Model *model) {
  auto cnodeT = std::make_unique<schema::CNodeT>();
  if (cnodeT == nullptr) {
    MS_LOG(ERROR) << "cannot allocate node";
    return nullptr;
  }
  cnodeT->inputIndex = inputIndex;
  cnodeT->outputIndex = outputIndex;
  cnodeT->name = kernel->name();
  cnodeT->quantType = GetNodeQuantType(kernel);
  // find kernel in model
  auto *node = FindNode(kernel, model);
  if (node == nullptr) {
    MS_LOG(ERROR) << "cannot find kernel " + kernel->name() + " in model";
    return nullptr;
  }
  auto primitive = reinterpret_cast<schema::Primitive *>(const_cast<void *>(node->primitive_));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive of node " << node->name_ << " is null";
    return nullptr;
  }
  cnodeT->primitive = std::unique_ptr<schema::PrimitiveT>(primitive->UnPack());
  return cnodeT;
}

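// Verifies and unpacks a serialized MetaGraph buffer. Output indices are cleared because
// ExportNet recomputes them for the exported network.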
int TrainExport::LoadModel(void *buf, size_t buf_size) {
  flatbuffers::Verifier verify(reinterpret_cast<const uint8_t *>(buf), buf_size);
  if (!schema::VerifyMetaGraphBuffer(verify)) {
    MS_LOG(ERROR) << "model flatbuffer verify fail";
    return RET_ERROR;
  }
  meta_graph_ = schema::GetMetaGraph(buf)->UnPack();
  meta_graph_->outputIndex.clear();
  if (!meta_graph_->subGraph.empty()) {
    meta_graph_->subGraph[0]->outputIndices.clear();
  }
  return RET_OK;
}

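// Creates the NHWC output tensor of an inserted transpose node by permuting the source dims
// with the NCHW-to-NHWC order {0, 2, 3, 1}.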
std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) {
  auto &scTensor = meta_graph_->allTensors.at(id);
  auto tensorT = std::make_unique<schema::TensorT>();
  if (tensorT == nullptr) {
    MS_LOG(ERROR) << "Could not create tensor";
    return nullptr;
  }
  tensorT->nodeType = scTensor->nodeType;
  tensorT->dataType = scTensor->dataType;
  std::vector<int32_t> dims;
  std::vector<int32_t> val = {0, 2, 3, 1};
  if (scTensor->dims.size() == kTransformTensorDim) {
    for (size_t i = 0; i < val.size(); i++) {
      dims.push_back(scTensor->dims.at(val[i]));
    }
    tensorT->dims = dims;
  } else {
    tensorT->dims = scTensor->dims;
  }
  tensorT->format = schema::Format_NHWC;
  tensorT->name = scTensor->name + "_post";
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->enableHuffmanCode = false;
  return tensorT;
}

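// Creates the constant permutation tensor {0, 2, 3, 1} that feeds an inserted transpose node.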
std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_id) {
  auto tensorT = std::make_unique<schema::TensorT>();
  if (tensorT == nullptr) {
    MS_LOG(ERROR) << "Could not create tensor";
    return nullptr;
  }
  tensorT->nodeType = lite::NodeType_ValueNode;
  tensorT->dataType = TypeId::kNumberTypeInt32;
  tensorT->dims = {kTransformTensorDim};
  tensorT->format = schema::Format_NCHW;
  tensorT->name = "const-" + std::to_string(last_id);
  tensorT->refCount = 0;
  tensorT->offset = 0;
  tensorT->enableHuffmanCode = false;
  int32_t val[] = {0, 2, 3, 1};
  uint8_t *valp = reinterpret_cast<uint8_t *>(val);
  tensorT->data = std::vector<uint8_t>(valp, valp + sizeof(val));
  return tensorT;
}

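// Creates the transpose CNodeT itself; its inputs (source tensor and permutation constant) and
// its output are wired up by AddTransformNode.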
std::unique_ptr<schema::CNodeT> TrainExport::CreateTransformNode(std::vector<uint32_t> inputIndex,
                                                                 std::vector<uint32_t> outputIndex, size_t id) {
  auto cnodeT = std::make_unique<schema::CNodeT>();
  if (cnodeT == nullptr) {
    MS_LOG(ERROR) << "cannot allocate node";
    return nullptr;
  }
  cnodeT->inputIndex = inputIndex;
  cnodeT->outputIndex = outputIndex;
  cnodeT->name = "transpose-" + std::to_string(id);
  cnodeT->quantType = schema::QuantType_QUANT_NONE;
  cnodeT->primitive = std::make_unique<schema::PrimitiveT>();
  cnodeT->primitive->value.type = schema::PrimitiveType_Transpose;
  return cnodeT;
}

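// For every connection recorded in connect_, appends a transpose node plus its permutation
// constant and output tensor, then redirects the connection to the transposed output.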
int TrainExport::AddTransformNode() {
  std::unordered_map<size_t, size_t> reconnect;
  size_t last_id = meta_graph_->allTensors.size();
  size_t last_node = meta_graph_->nodes.size();
  for (auto it : connect_) {
    auto tensorConst = CreateTransformConst(last_id);
    if (tensorConst == nullptr) {
      MS_LOG(ERROR) << "error in create tensor";
      return RET_ERROR;
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorConst));  // last_id
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
    auto tensorT = CreateTransformTensor(it.second);
    if (tensorT == nullptr) {
      MS_LOG(ERROR) << "error in create tensor";
      return RET_ERROR;
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorT));  // last_id + 1
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
    std::vector<uint32_t> in_idx = {static_cast<uint32_t>(it.second), static_cast<uint32_t>(last_id)};
    std::vector<uint32_t> out_idx = {static_cast<uint32_t>(last_id + 1)};
    reconnect[it.first] = last_id + 1;
    auto cnode = CreateTransformNode(in_idx, out_idx, last_node);
    if (cnode == nullptr) {
      MS_LOG(ERROR) << "error in node creation";
      return RET_ERROR;
    }
    meta_graph_->nodes.emplace_back(std::move(cnode));
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->nodeIndices.push_back(meta_graph_->nodes.size() - 1);
    }
    // advance past the two tensors and one node just added, so the next
    // connection gets its own constant, output tensor and node name
    last_id += 2;
    last_node++;
  }
  connect_ = reconnect;
  return RET_OK;
}

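// Applies the tensor-index offset of the previously loaded graph to the recorded connections.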
void TrainExport::PrepareRemap(int offset) {
  for (auto it : connect_) {
    remap_[it.first + offset] = it.second;
  }
}

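// Main export entry point: maps every kernel and tensor of the trained network into the meta
// graph, registers graph inputs and outputs, tags quantized nodes and topologically sorts the result.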
int TrainExport::ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
                           const std::vector<mindspore::lite::Tensor *> &tensors,
                           const std::vector<std::string> &output_names, const Model *model,
                           QuantizationType quant_type) {
  std::vector<size_t> map_index;
  std::set<size_t> out_set;
  quant_type_ = quant_type;
  if (meta_graph_ == nullptr) {
    int status = ExportInit(model->name_, model->version_);
    if (status != RET_OK) {
      return status;
    }
  }
  // compute the offset only after ExportInit, so meta_graph_ is guaranteed to exist
  int offset = meta_graph_->allTensors.size();
  int tensor_idx = offset;
  PrepareRemap(offset);

  for (const auto kernel : kernels) {
    std::vector<uint32_t> in_idx, out_idx;
    for (const auto tensor : kernel->in_tensors()) {
      size_t id = TSFindTensor(tensors, tensor);
      if (id == tensors.size()) {
        MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
        return RET_ERROR;
      }
      id += offset;
      auto it = remap_.find(id);
      if (it == remap_.end()) {
        remap_[id] = tensor_idx;
        in_idx.push_back(tensor_idx);
        map_index.push_back(id);
        tensor_idx++;
      } else {
        in_idx.push_back(it->second);
      }
    }
    for (const auto tensor : kernel->out_tensors()) {
      size_t id = TSFindTensor(tensors, tensor);
      if (id == tensors.size()) {
        MS_LOG(ERROR) << "cannot find tensor " + tensor->ToString() + " in model";
        return RET_ERROR;
      }
      id += offset;
      auto it = remap_.find(id);
      if (it == remap_.end()) {
        remap_[id] = tensor_idx;
        map_index.push_back(id);
        out_idx.push_back(tensor_idx);
        out_set.insert(tensor_idx);
        tensor_idx++;
      } else {
        out_idx.push_back(it->second);
        out_set.insert(it->second);
      }
    }
    auto ret = CreateAndAddCNode(kernel, in_idx, out_idx, model);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "failed to create cnode";
      return ret;
    }
  }
  for (auto id : map_index) {
    size_t pid = id - offset;
    mindspore::lite::Tensor *tensor = tensors.at(pid);
    schema::Tensor *scTensor = model->all_tensors_.at(pid);
    auto tensorT = CreateTensor(tensor, scTensor);
    if (tensorT == nullptr) {
      MS_LOG(ERROR) << "error in tensor creation";
      return RET_ERROR;
    }
    if (out_set.find(remap_[id]) == out_set.end()) {
      if (IsInputTensor(*tensorT)) {
        meta_graph_->inputIndex.push_back(remap_[id]);
        if (!meta_graph_->subGraph.empty()) {
          meta_graph_->subGraph[0]->inputIndices.push_back(remap_[id]);
        }
      }
    }
    // find output tensor
    if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) {
      meta_graph_->outputIndex.push_back(remap_[id]);
      if (!meta_graph_->subGraph.empty()) {
        meta_graph_->subGraph[0]->outputIndices.push_back(remap_[id]);
      }
    }
    meta_graph_->allTensors.emplace_back(std::move(tensorT));
    if (!meta_graph_->subGraph.empty()) {
      meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1);
    }
  }
  TagQuantizedNodes();  // do another loop to mark QUANT_WEIGHT nodes
  auto status = TopologicalSort();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TopologicalSort failed.";
    return RET_ERROR;
  }

  return RET_OK;
}

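// Kahn-style topological sort: a node is emitted once all of its input tensors are "sinked",
// i.e. they are graph inputs, constants, or outputs of already-emitted nodes.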
int TrainExport::TopologicalSort() {
  MS_ASSERT(meta_graph_ != nullptr);
  std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
  std::vector<size_t> sinked_tensor_idxes;
  for (auto &subgraph : meta_graph_->subGraph) {
    std::copy(subgraph->inputIndices.begin(), subgraph->inputIndices.end(), std::back_inserter(sinked_tensor_idxes));
  }
  // put all const tensor indices into sinked_tensor_idxes
  for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) {
    if (meta_graph_->allTensors.at(i)->nodeType == NodeType_ValueNode) {
      sinked_tensor_idxes.push_back(i);
    }
  }
  auto &old_nodes = meta_graph_->nodes;
  std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
  // seed the queue with all nodes that have no unresolved dependency
  for (size_t i = 0; i < meta_graph_->subGraph.size(); i++) {
    std::vector<unsigned int> new_subgraph_node_indices = {};
    auto subgraph_node_indices = meta_graph_->subGraph[i]->nodeIndices;

    for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
      auto &node = old_nodes[subgraph_node_indices[j]];
      if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
        sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
        op_queue.push(std::move(node));
      }
    }
    while (!op_queue.empty()) {
      auto &node = op_queue.front();
      auto post_node_idxes = GetOutputNodeIdx(*meta_graph_, *(node.get()));
      sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
      for (auto post_node_idx : post_node_idxes) {
        if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
          auto &post_node = old_nodes.at(post_node_idx);
          if (post_node == nullptr) {  // already queued via another predecessor
            continue;
          }
          // check whether post_node has become free of dependencies
          if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
            op_queue.push(std::move(post_node));
          }
        }
      }
      new_nodes.emplace_back(std::move(node));
      new_subgraph_node_indices.push_back(new_nodes.size() - 1);
      op_queue.pop();
    }
    meta_graph_->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
  }
  if (new_nodes.size() != old_nodes.size()) {
    MS_LOG(ERROR) << "Unknown error in TopologicalSort, old_nodes size: " << old_nodes.size()
                  << ", new_nodes size: " << new_nodes.size();
    return RET_ERROR;
  }
  meta_graph_->nodes.swap(new_nodes);
  return RET_OK;
}

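// A node has no unresolved dependency when every one of its input tensors is already sinked.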
bool TrainExport::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
                                  const std::vector<size_t> &sinked_tensor_idxes) {
  MS_ASSERT(node != nullptr);
  return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
                     [&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
}

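// Allocates a fresh MetaGraphT containing a single subgraph named after the model.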
int TrainExport::ExportInit(const std::string model_name, std::string version) {
  meta_graph_ = new (std::nothrow) schema::MetaGraphT();
  if (meta_graph_ == nullptr) {
    MS_LOG(ERROR) << "cannot allocate meta_graph";
    return RET_ERROR;
  }
  auto sub_graph = std::make_unique<schema::SubGraphT>();
  if (sub_graph == nullptr) {
    MS_LOG(ERROR) << "cannot allocate SubGraphT";
    return RET_ERROR;
  }
  sub_graph->name = model_name + "_subgraph";
  meta_graph_->subGraph.emplace_back(std::move(sub_graph));
  meta_graph_->fmkType = kFmkVal;
  meta_graph_->name = model_name;
  meta_graph_->version = version;
  return RET_OK;
}

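// Serializes the meta graph to file_name_ using the flatbuffer Storage helper.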
int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }

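// A tensor is a graph input when it carries no data but has a non-empty shape.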
int TrainExport::IsInputTensor(const schema::TensorT &t) {
  int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
  return ((t.data.size() == 0) && (total_dims != 0));
}

TrainExport::~TrainExport() { delete meta_graph_; }
}  // namespace lite
}  // namespace mindspore