1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <cmath>
18 #include <memory>
19 #include "schema/inner/model_generated.h"
20 #include "common/common_test.h"
21 #include "include/errorcode.h"
22 #include "src/common/log_adapter.h"
23 #include "src/common/file_utils.h"
24 #include "src/litert/lite_session.h"
25 #include "tools/common/meta_graph_serializer.h"
26
27 namespace mindspore {
28 class SubGraphTest : public mindspore::CommonTest {
29 public:
30 SubGraphTest() = default;
31 };
32
// Builds a MetaGraphT made of three sub-graphs that reference each other through PartialFusion
// nodes (main_graph -> while_cond -> while_body -> while_cond ...), serializes it, loads it back
// through lite::Model::Import and checks that a LiteSession can compile and run it.
//
// Tensor / node wiring (indices into allTensors, nodes listed in creation order):
//   main_graph : Add0(0,1)->2, Add1(2,3)->4, partial_cond(4)->9 [sub-graph 1], Add5(9,13)->14
//   while_cond : Add2(4,5)->6, less(6,15)->7, switch(7,4)->(8,9), partial_body(8)->4 [sub-graph 2]
//   while_body : Add3(8,10)->11, Add4(11,12)->4, partial_cond1(4)->9 [sub-graph 1]
TEST_F(SubGraphTest, RecursiveSubGraphTest) {
  using IndexList = decltype(schema::CNodeT::inputIndex);
  using SubGraphIndexList = decltype(schema::SubGraphT::inputIndices);

  constexpr size_t kTensorCount = 16;
  // Written relative to the working directory so the test does not depend on one machine's layout.
  // The model is read back under the same name with a ".ms" suffix, matching the original test.
  const char kModelStem[] = "./recursive_subgraph";
  const char kModelFile[] = "./recursive_subgraph.ms";
  // Fixed graph input so every run takes the same path. Three Add nodes each add a constant 1
  // before `less` compares the sum with 1, so a negative start makes that comparison true at first
  // and false after a bounded number of +2 body steps.
  // NOTE(review): presumably a true `less` routes through while_body - confirm against the Switch kernel.
  constexpr float kInputValue = -10.0f;

  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->allTensors.resize(kTensorCount);
  auto &tensors = meta_graph->allTensors;

  // Float32 / NHWC tensor with no shape and no data.
  auto make_tensor = [](auto node_type) {
    auto tensor = std::make_unique<schema::TensorT>();
    tensor->nodeType = node_type;
    tensor->format = schema::Format_NHWC;
    tensor->dataType = TypeId::kNumberTypeFloat32;
    return tensor;
  };
  // Same as make_tensor but with dims {1}; when `fill_one` is set it also carries the constant 1.0f.
  auto make_scalar = [&make_tensor](auto node_type, bool fill_one) {
    auto tensor = make_tensor(node_type);
    tensor->dims = {1};
    if (fill_one) {
      tensor->data.resize(sizeof(float));
      // data() cannot be null here: the vector was just resized to a non-zero length.
      *reinterpret_cast<float *>(tensor->data.data()) = 1;
    }
    return tensor;
  };
  // Appends a CNodeT to meta_graph->nodes. `prim_value` is a raw pointer to the matching *T
  // primitive object; it is stored in the PrimitiveT union exactly as the original test did.
  auto append_node = [&meta_graph](const char *name, IndexList inputs, IndexList outputs,
                                   schema::PrimitiveType prim_type, void *prim_value) {
    auto node = std::make_unique<schema::CNodeT>();
    node->inputIndex = std::move(inputs);
    node->outputIndex = std::move(outputs);
    node->primitive = std::make_unique<schema::PrimitiveT>();
    node->primitive->value.type = prim_type;
    node->primitive->value.value = prim_value;
    node->name = name;
    meta_graph->nodes.emplace_back(std::move(node));
  };
  // AddFusion node without activation.
  auto append_add = [&append_node](const char *name, IndexList inputs, IndexList outputs) {
    auto *attr = new schema::AddFusionT;
    attr->activation_type = schema::ActivationType_NO_ACTIVATION;
    append_node(name, std::move(inputs), std::move(outputs), schema::PrimitiveType_AddFusion, attr);
  };
  // PartialFusion node that refers to sub-graph `sub_graph_index`.
  auto append_partial = [&append_node](const char *name, IndexList inputs, IndexList outputs, int sub_graph_index) {
    auto *attr = new schema::PartialFusionT;
    attr->sub_graph_index = sub_graph_index;
    append_node(name, std::move(inputs), std::move(outputs), schema::PrimitiveType_PartialFusion, attr);
  };
  // Appends a SubGraphT describing which nodes/tensors belong to it and what its inputs/outputs are.
  auto append_sub_graph = [&meta_graph](const char *name, SubGraphIndexList inputs, SubGraphIndexList outputs,
                                        SubGraphIndexList nodes, SubGraphIndexList tensor_ids) {
    auto sub_graph = std::make_unique<schema::SubGraphT>();
    sub_graph->name = name;
    sub_graph->inputIndices = std::move(inputs);
    sub_graph->outputIndices = std::move(outputs);
    sub_graph->nodeIndices = std::move(nodes);
    sub_graph->tensorIndices = std::move(tensor_ids);
    meta_graph->subGraph.emplace_back(std::move(sub_graph));
  };

  // ---- subgraph-0: main_graph (nodes 0..3) ----
  tensors[0] = make_scalar(lite::NodeType_Parameter, false);  // graph input, filled before RunGraph
  tensors[1] = make_scalar(lite::NodeType_Parameter, true);
  tensors[2] = make_tensor(lite::NodeType_Parameter);
  append_add("Add0", {0, 1}, {2});

  tensors[3] = make_scalar(lite::NodeType_ValueNode, true);
  tensors[4] = make_tensor(lite::NodeType_Parameter);
  append_add("Add1", {2, 3}, {4});

  append_partial("partial_cond", {4}, {9}, 1);

  tensors[13] = make_scalar(lite::NodeType_ValueNode, true);
  tensors[14] = make_tensor(lite::NodeType_Parameter);
  append_add("Add5", {9, 13}, {14});

  append_sub_graph("main_graph", {0}, {14}, {0, 1, 2, 3}, {0, 1, 2, 3, 4, 9, 13, 14});

  // ---- subgraph-1: while_cond (nodes 4..7) ----
  tensors[5] = make_scalar(lite::NodeType_ValueNode, true);
  tensors[6] = make_tensor(lite::NodeType_Parameter);
  append_add("Add2", {4, 5}, {6});

  tensors[15] = make_scalar(lite::NodeType_ValueNode, true);
  tensors[7] = make_tensor(lite::NodeType_Parameter);
  append_node("less", {6, 15}, {7}, schema::PrimitiveType_Less, new schema::LessT);

  tensors[8] = make_tensor(lite::NodeType_Parameter);
  tensors[9] = make_tensor(lite::NodeType_Parameter);
  append_node("switch", {7, 4}, {8, 9}, schema::PrimitiveType_Switch, new schema::SwitchT);

  append_partial("partial_body", {8}, {4}, 2);

  append_sub_graph("while_cond", {4}, {9}, {4, 5, 6, 7}, {4, 5, 6, 7, 8, 9, 15});

  // ---- subgraph-2: while_body (nodes 8..10) ----
  tensors[10] = make_scalar(lite::NodeType_ValueNode, true);
  tensors[11] = make_tensor(lite::NodeType_Parameter);
  append_add("Add3", {8, 10}, {11});

  // Add4 writes back into tensor 4, which was already created for the main graph.
  tensors[12] = make_scalar(lite::NodeType_ValueNode, true);
  append_add("Add4", {11, 12}, {4});

  append_partial("partial_cond1", {4}, {9}, 1);

  append_sub_graph("while_body", {8}, {9}, {8, 9, 10}, {8, 10, 11, 12, 4, 9});

  meta_graph->name = "graph";
  meta_graph->version = Version();

  // ---- serialize, reload ----
  // Checking the status keeps a stale model file from an earlier run from being picked up silently.
  ASSERT_EQ(lite::MetaGraphSerializer::Save(*meta_graph, kModelStem), lite::RET_OK);

  size_t size = 0;
  std::unique_ptr<char[]> graph_buf(lite::ReadFile(kModelFile, &size));
  ASSERT_NE(graph_buf.get(), nullptr);

  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf.get(), size));
  // The original test freed the buffer right after Import; doing it before the assertion
  // means a failed import no longer leaks it.
  graph_buf.reset();
  ASSERT_NE(model, nullptr);

  // ---- compile and run ----
  auto context = std::make_shared<mindspore::lite::InnerContext>();
  auto &cpu_device_ctx = context->device_list_[0];
  cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
  context->thread_num_ = 2;
  auto session = std::shared_ptr<lite::LiteSession>(lite::LiteSession::CreateSession(context));
  ASSERT_NE(session, nullptr);
  auto ret = session->CompileGraph(model.get());
  ASSERT_EQ(ret, lite::RET_OK);

  // The only graph input is tensor 0 (float32, dims {1}); give it a defined value instead of
  // leaving freshly allocated memory uninitialized.
  auto inputs = session->GetInputs();
  for (auto *input : inputs) {
    auto *input_data = static_cast<float *>(input->MutableData());
    ASSERT_NE(input_data, nullptr);
    input_data[0] = kInputValue;
  }
  ret = session->RunGraph();
  ASSERT_EQ(ret, lite::RET_OK);
}
338 } // namespace mindspore
339