/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "coder/train.h"
#include <memory>
#include <set>
#include <array>
#include <queue>
#include <string>
#include <vector>
#include <algorithm>
#include "schema/ops_generated.h"
#include "src/common/prim_util.h"

namespace mindspore::lite::micro {
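// Collects every operator coder that transitively feeds the given edge operator by walking input_ops()
// breadth-first; the returned set is the inference subgraph and excludes the edge operator itself.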
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
  std::set<OperatorCoder *> subgraph;
  std::queue<OperatorCoder *> to_visit;
  to_visit.push(edge);
  while (!to_visit.empty()) {
    size_t size = to_visit.size();
    for (size_t i = 0; i < size; ++i) {
      OperatorCoder *curr = to_visit.front();
      to_visit.pop();
      if (subgraph.find(curr) != subgraph.end()) {
        continue;
      }
      subgraph.insert(curr);
      for (const auto &op : curr->input_ops()) {
        to_visit.push(op);
      }
    }
  }
  auto item = subgraph.find(edge);
  if (item == subgraph.end()) {
    MS_LOG(ERROR) << "failed to find the edge in the subgraph";
    return subgraph;
  }
  // erase edge operator coder from subgraph
  subgraph.erase(item);
  return subgraph;
}

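// Splits the generated code blocks into a training view and an inference view: every block belongs to the
// train blocks, while only the blocks of operators that feed the loss node belong to the inference blocks.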
int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders,
                                  int schema_version) {
  if (context == nullptr) {
    MS_LOG(ERROR) << "input context invalid";
    return RET_ERROR;
  }
  const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
                                         schema::PrimitiveType_BinaryCrossEntropy,
                                         schema::PrimitiveType_SmoothL1Loss,
                                         schema::PrimitiveType_SmoothL1LossGrad,
                                         schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
                                         schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
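  // pick the first operator coder whose primitive type matches one of the loss types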
  OperatorCoder *loss_op = nullptr;
  for (const auto &opcoder : op_coders) {
    const Model::Node *node = opcoder->node();
    int primitive_type = GetPrimitiveType(node->primitive_, schema_version);
    auto item = std::find(loss_types.begin(), loss_types.end(), primitive_type);
    if (item != loss_types.end()) {
      loss_op = opcoder.get();
      break;
    }
  }
  MS_CHECK_PTR(loss_op);
  size_t op_num = op_coders.size();
  std::vector<std::string> code_blocks = context->code_blocks();
  if (op_num != code_blocks.size()) {
    MS_LOG(ERROR) << "the number of code blocks and op coders is not equal";
    return RET_ERROR;
  }
  std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
  std::vector<std::string> inference_blocks;
  std::vector<std::string> train_blocks;
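  // every block belongs to the train graph; only blocks of ops that feed the loss also belong to the inference graph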
  for (size_t i = 0; i < op_num; ++i) {
    auto &opcoder = op_coders.at(i);
    const std::string &block = code_blocks.at(i);
    if (inference_ops.find(opcoder.get()) != inference_ops.end()) {
      inference_blocks.push_back(block);
    }
    train_blocks.push_back(block);
  }
  context->set_inference_blocks(inference_blocks);
  context->set_train_blocks(train_blocks);
  return RET_OK;
}
}  // namespace mindspore::lite::micro