1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "mindspore/lite/include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/model.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "tools/common/storage.h"
#include "include/version.h"
30
31 namespace mindspore {
32 class SubGraphTest : public mindspore::CommonTest {
33 public:
34 SubGraphTest() = default;
35 };
36
TEST_F(SubGraphTest,RecursiveSubGraphTest)37 TEST_F(SubGraphTest, RecursiveSubGraphTest) {
38 auto meta_graph = std::make_shared<schema::MetaGraphT>();
39 meta_graph->allTensors.resize(16);
40 { // subgraph-0
41 { // add-0
42 auto add_0 = std::make_unique<schema::CNodeT>();
43 add_0->inputIndex = {0, 1};
44 add_0->outputIndex = {2};
45 add_0->primitive = std::make_unique<schema::PrimitiveT>();
46 add_0->primitive->value.type = schema::PrimitiveType_AddFusion;
47 auto add_0_prim = new schema::AddFusionT;
48 add_0_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
49 add_0->primitive->value.value = add_0_prim;
50 add_0->name = "Add0";
51 auto tensor_0 = std::make_unique<schema::TensorT>();
52 tensor_0->nodeType = lite::NodeType_ValueNode;
53 tensor_0->format = schema::Format_NHWC;
54 tensor_0->dataType = TypeId::kNumberTypeFloat32;
55 tensor_0->dims = {1};
56 auto tensor_1 = std::make_unique<schema::TensorT>();
57 tensor_1->nodeType = lite::NodeType_ValueNode;
58 tensor_1->format = schema::Format_NHWC;
59 tensor_1->dataType = TypeId::kNumberTypeFloat32;
60 tensor_1->dims = {1};
61 tensor_1->data.resize(sizeof(float));
62 auto data1 = reinterpret_cast<float *>(tensor_1->data.data());
63 ASSERT_NE(data1, nullptr);
64 data1[0] = 1;
65 auto tensor_2 = std::make_unique<schema::TensorT>();
66 tensor_2->nodeType = lite::NodeType_Parameter;
67 tensor_2->format = schema::Format_NHWC;
68 tensor_2->dataType = TypeId::kNumberTypeFloat32;
69 meta_graph->nodes.emplace_back(std::move(add_0));
70 meta_graph->allTensors[0] = std::move(tensor_0);
71 meta_graph->allTensors[1] = std::move(tensor_1);
72 meta_graph->allTensors[2] = std::move(tensor_2);
73 }
74 { // add-1
75 auto add_1 = std::make_unique<schema::CNodeT>();
76 add_1->inputIndex = {2, 3};
77 add_1->outputIndex = {4};
78 add_1->primitive = std::make_unique<schema::PrimitiveT>();
79 add_1->primitive->value.type = schema::PrimitiveType_AddFusion;
80 auto add_1_prim = new schema::AddFusionT;
81 add_1_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
82 add_1->primitive->value.value = add_1_prim;
83 add_1->name = "Add1";
84 auto tensor_3 = std::make_unique<schema::TensorT>();
85 tensor_3->nodeType = lite::NodeType_ValueNode;
86 tensor_3->format = schema::Format_NHWC;
87 tensor_3->dataType = TypeId::kNumberTypeFloat32;
88 tensor_3->dims = {1};
89 tensor_3->data.resize(sizeof(float));
90 auto data3 = reinterpret_cast<float *>(tensor_3->data.data());
91 ASSERT_NE(data3, nullptr);
92 data3[0] = 1;
93 auto tensor_4 = std::make_unique<schema::TensorT>();
94 tensor_4->nodeType = lite::NodeType_Parameter;
95 tensor_4->format = schema::Format_NHWC;
96 tensor_4->dataType = TypeId::kNumberTypeFloat32;
97 meta_graph->nodes.emplace_back(std::move(add_1));
98 meta_graph->allTensors[3] = std::move(tensor_3);
99 meta_graph->allTensors[4] = std::move(tensor_4);
100 }
101 { // partial cond
102 auto partial_cond = std::make_unique<schema::CNodeT>();
103 partial_cond->inputIndex = {4};
104 partial_cond->outputIndex = {9};
105 partial_cond->primitive = std::make_unique<schema::PrimitiveT>();
106 partial_cond->primitive->value.type = schema::PrimitiveType_PartialFusion;
107 auto partial_cond_prim = new schema::PartialFusionT;
108 partial_cond_prim->sub_graph_index = 1;
109 partial_cond->primitive->value.value = partial_cond_prim;
110 partial_cond->name = "partial_cond";
111 meta_graph->nodes.emplace_back(std::move(partial_cond));
112 }
113 { // add-5
114 auto add_5 = std::make_unique<schema::CNodeT>();
115 add_5->inputIndex = {9, 13};
116 add_5->outputIndex = {14};
117 add_5->primitive = std::make_unique<schema::PrimitiveT>();
118 add_5->primitive->value.type = schema::PrimitiveType_AddFusion;
119 auto add_5_prim = new schema::AddFusionT;
120 add_5_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
121 add_5->primitive->value.value = add_5_prim;
122 add_5->name = "Add5";
123 auto tensor_13 = std::make_unique<schema::TensorT>();
124 tensor_13->nodeType = lite::NodeType_ValueNode;
125 tensor_13->format = schema::Format_NHWC;
126 tensor_13->dataType = TypeId::kNumberTypeFloat32;
127 tensor_13->dims = {1};
128 tensor_13->data.resize(sizeof(float));
129 auto data13 = reinterpret_cast<float *>(tensor_13->data.data());
130 ASSERT_NE(data13, nullptr);
131 data13[0] = 1;
132 auto tensor_14 = std::make_unique<schema::TensorT>();
133 tensor_14->nodeType = lite::NodeType_Parameter;
134 tensor_14->format = schema::Format_NHWC;
135 tensor_14->dataType = TypeId::kNumberTypeFloat32;
136 meta_graph->nodes.emplace_back(std::move(add_5));
137 meta_graph->allTensors[13] = std::move(tensor_13);
138 meta_graph->allTensors[14] = std::move(tensor_14);
139 }
140 auto sub_graph_0 = std::make_unique<schema::SubGraphT>();
141 sub_graph_0->name = "main_graph";
142 sub_graph_0->inputIndices = {0};
143 sub_graph_0->outputIndices = {14};
144 sub_graph_0->nodeIndices = {0, 1, 2, 3};
145 sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 9, 13, 14};
146 meta_graph->subGraph.emplace_back(std::move(sub_graph_0));
147 }
148 { // subgraph-1
149 { // add-2
150 auto add_2 = std::make_unique<schema::CNodeT>();
151 add_2->inputIndex = {4, 5};
152 add_2->outputIndex = {6};
153 add_2->primitive = std::make_unique<schema::PrimitiveT>();
154 add_2->primitive->value.type = schema::PrimitiveType_AddFusion;
155 auto add_2_prim = new schema::AddFusionT;
156 add_2_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
157 add_2->primitive->value.value = add_2_prim;
158 add_2->name = "Add2";
159 auto tensor_5 = std::make_unique<schema::TensorT>();
160 tensor_5->nodeType = lite::NodeType_ValueNode;
161 tensor_5->format = schema::Format_NHWC;
162 tensor_5->dataType = TypeId::kNumberTypeFloat32;
163 tensor_5->dims = {1};
164 tensor_5->data.resize(sizeof(float));
165 auto data5 = reinterpret_cast<float *>(tensor_5->data.data());
166 ASSERT_NE(data5, nullptr);
167 data5[0] = 1;
168 auto tensor_6 = std::make_unique<schema::TensorT>();
169 tensor_6->nodeType = lite::NodeType_Parameter;
170 tensor_6->format = schema::Format_NHWC;
171 tensor_6->dataType = TypeId::kNumberTypeFloat32;
172 meta_graph->nodes.emplace_back(std::move(add_2));
173 meta_graph->allTensors[5] = std::move(tensor_5);
174 meta_graph->allTensors[6] = std::move(tensor_6);
175 }
176 { // less
177 auto less = std::make_unique<schema::CNodeT>();
178 less->inputIndex = {6, 15};
179 less->outputIndex = {7};
180 less->primitive = std::make_unique<schema::PrimitiveT>();
181 less->primitive->value.type = schema::PrimitiveType_Less;
182 auto less_prim = new schema::LessT;
183 less->primitive->value.value = less_prim;
184 less->name = "less";
185 auto tensor_15 = std::make_unique<schema::TensorT>();
186 tensor_15->nodeType = lite::NodeType_ValueNode;
187 tensor_15->format = schema::Format_NHWC;
188 tensor_15->dataType = TypeId::kNumberTypeFloat32;
189 tensor_15->dims = {1};
190 tensor_15->data.resize(sizeof(float));
191 auto data15 = reinterpret_cast<float *>(tensor_15->data.data());
192 ASSERT_NE(data15, nullptr);
193 data15[0] = 1;
194 auto tensor_7 = std::make_unique<schema::TensorT>();
195 tensor_7->nodeType = lite::NodeType_Parameter;
196 tensor_7->format = schema::Format_NHWC;
197 tensor_7->dataType = TypeId::kNumberTypeFloat32;
198 meta_graph->nodes.emplace_back(std::move(less));
199 meta_graph->allTensors[7] = std::move(tensor_7);
200 meta_graph->allTensors[15] = std::move(tensor_15);
201 }
202 { // switch
203 auto switchop = std::make_unique<schema::CNodeT>();
204 switchop->inputIndex = {7, 4};
205 switchop->outputIndex = {8, 9};
206 switchop->primitive = std::make_unique<schema::PrimitiveT>();
207 switchop->primitive->value.type = schema::PrimitiveType_Switch;
208 auto switch_prim = new schema::SwitchT;
209 switchop->primitive->value.value = switch_prim;
210 switchop->name = "switch";
211 auto tensor_8 = std::make_unique<schema::TensorT>();
212 tensor_8->nodeType = lite::NodeType_Parameter;
213 tensor_8->format = schema::Format_NHWC;
214 tensor_8->dataType = TypeId::kNumberTypeFloat32;
215 auto tensor_9 = std::make_unique<schema::TensorT>();
216 tensor_9->nodeType = lite::NodeType_Parameter;
217 tensor_9->format = schema::Format_NHWC;
218 tensor_9->dataType = TypeId::kNumberTypeFloat32;
219 meta_graph->nodes.emplace_back(std::move(switchop));
220 meta_graph->allTensors[8] = std::move(tensor_8);
221 meta_graph->allTensors[9] = std::move(tensor_9);
222 }
223 { // partial body
224 auto partial_body = std::make_unique<schema::CNodeT>();
225 partial_body->inputIndex = {8};
226 partial_body->outputIndex = {4};
227 partial_body->primitive = std::make_unique<schema::PrimitiveT>();
228 partial_body->primitive->value.type = schema::PrimitiveType_PartialFusion;
229 auto partial_body_prim = new schema::PartialFusionT;
230 partial_body_prim->sub_graph_index = 2;
231 partial_body->primitive->value.value = partial_body_prim;
232 partial_body->name = "partial_body";
233 meta_graph->nodes.emplace_back(std::move(partial_body));
234 }
235 auto sub_graph_1 = std::make_unique<schema::SubGraphT>();
236 sub_graph_1->name = "while_cond";
237 sub_graph_1->inputIndices = {4};
238 sub_graph_1->outputIndices = {9};
239 sub_graph_1->nodeIndices = {4, 5, 6, 7};
240 sub_graph_1->tensorIndices = {4, 5, 6, 7, 8, 9, 15};
241 meta_graph->subGraph.emplace_back(std::move(sub_graph_1));
242 }
243 { // subgraph-2
244 { // add-3
245 auto add_3 = std::make_unique<schema::CNodeT>();
246 add_3->inputIndex = {8, 10};
247 add_3->outputIndex = {11};
248 add_3->primitive = std::make_unique<schema::PrimitiveT>();
249 add_3->primitive->value.type = schema::PrimitiveType_AddFusion;
250 auto add_3_prim = new schema::AddFusionT;
251 add_3_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
252 add_3->primitive->value.value = add_3_prim;
253 add_3->name = "Add3";
254 auto tensor_10 = std::make_unique<schema::TensorT>();
255 tensor_10->nodeType = lite::NodeType_ValueNode;
256 tensor_10->format = schema::Format_NHWC;
257 tensor_10->dataType = TypeId::kNumberTypeFloat32;
258 tensor_10->dims = {1};
259 tensor_10->data.resize(sizeof(float));
260 auto data10 = reinterpret_cast<float *>(tensor_10->data.data());
261 ASSERT_NE(data10, nullptr);
262 data10[0] = 1;
263 auto tensor_11 = std::make_unique<schema::TensorT>();
264 tensor_11->nodeType = lite::NodeType_Parameter;
265 tensor_11->format = schema::Format_NHWC;
266 tensor_11->dataType = TypeId::kNumberTypeFloat32;
267 meta_graph->nodes.emplace_back(std::move(add_3));
268 meta_graph->allTensors[10] = std::move(tensor_10);
269 meta_graph->allTensors[11] = std::move(tensor_11);
270 }
271 { // add-4
272 auto add_4 = std::make_unique<schema::CNodeT>();
273 add_4->inputIndex = {11, 12};
274 add_4->outputIndex = {4};
275 add_4->primitive = std::make_unique<schema::PrimitiveT>();
276 add_4->primitive->value.type = schema::PrimitiveType_AddFusion;
277 auto add_4_prim = new schema::AddFusionT;
278 add_4_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
279 add_4->primitive->value.value = add_4_prim;
280 add_4->name = "Add4";
281 auto tensor_12 = std::make_unique<schema::TensorT>();
282 tensor_12->nodeType = lite::NodeType_ValueNode;
283 tensor_12->format = schema::Format_NHWC;
284 tensor_12->dataType = TypeId::kNumberTypeFloat32;
285 tensor_12->dims = {1};
286 tensor_12->data.resize(sizeof(float));
287 auto data12 = reinterpret_cast<float *>(tensor_12->data.data());
288 ASSERT_NE(data12, nullptr);
289 data12[0] = 1;
290 meta_graph->nodes.emplace_back(std::move(add_4));
291 meta_graph->allTensors[12] = std::move(tensor_12);
292 }
293 { // partial cond
294 auto partial_cond = std::make_unique<schema::CNodeT>();
295 partial_cond->inputIndex = {4};
296 partial_cond->outputIndex = {9};
297 partial_cond->primitive = std::make_unique<schema::PrimitiveT>();
298 partial_cond->primitive->value.type = schema::PrimitiveType_PartialFusion;
299 auto partial_cond_prim = new schema::PartialFusionT;
300 partial_cond_prim->sub_graph_index = 1;
301 partial_cond->primitive->value.value = partial_cond_prim;
302 partial_cond->name = "partial_cond1";
303 meta_graph->nodes.emplace_back(std::move(partial_cond));
304 }
305 auto sub_graph_2 = std::make_unique<schema::SubGraphT>();
306 sub_graph_2->name = "while_body";
307 sub_graph_2->inputIndices = {8};
308 sub_graph_2->outputIndices = {9};
309 sub_graph_2->nodeIndices = {8, 9, 10};
310 sub_graph_2->tensorIndices = {8, 10, 11, 12, 4, 9};
311 meta_graph->subGraph.emplace_back(std::move(sub_graph_2));
312 }
313 meta_graph->name = "graph";
314 meta_graph->version = lite::Version();
315 // -----------------------------------------------------------------------
316 lite::Storage::Save(*meta_graph,
317 "/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph");
318 // -----------------------------------------------------------------------
319 size_t size = 0;
320 char *graph_buf = lite::ReadFile(
321 "/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph.ms", &size);
322 ASSERT_NE(graph_buf, nullptr);
323
324 auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size));
325 ASSERT_NE(model, nullptr);
326 delete[](graph_buf);
327 lite::Context context;
328 auto &cpu_device_ctx = context.device_list_[0];
329 cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
330 context.thread_num_ = 2;
331 auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
332 ASSERT_NE(session, nullptr);
333 auto ret = session->CompileGraph(model.get());
334 ASSERT_EQ(ret, lite::RET_OK);
335 auto inputs = session->GetInputs();
336 for (auto *input : inputs) {
337 (void)input->MutableData();
338 }
339 ret = session->RunGraph();
340 ASSERT_EQ(ret, lite::RET_OK);
341 }
342 } // namespace mindspore
343