• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "common/common_test.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/file_utils.h"
#include "src/litert/lite_session.h"
#include "tools/common/meta_graph_serializer.h"
26 
27 namespace mindspore {
28 class SubGraphTest : public mindspore::CommonTest {
29  public:
30   SubGraphTest() = default;
31 };
32 
TEST_F(SubGraphTest,RecursiveSubGraphTest)33 TEST_F(SubGraphTest, RecursiveSubGraphTest) {
34   auto meta_graph = std::make_shared<schema::MetaGraphT>();
35   meta_graph->allTensors.resize(16);
36   {    // subgraph-0
37     {  // add-0
38       auto add_0 = std::make_unique<schema::CNodeT>();
39       add_0->inputIndex = {0, 1};
40       add_0->outputIndex = {2};
41       add_0->primitive = std::make_unique<schema::PrimitiveT>();
42       add_0->primitive->value.type = schema::PrimitiveType_AddFusion;
43       auto add_0_prim = new schema::AddFusionT;
44       add_0_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
45       add_0->primitive->value.value = add_0_prim;
46       add_0->name = "Add0";
47       auto tensor_0 = std::make_unique<schema::TensorT>();
48       tensor_0->nodeType = lite::NodeType_Parameter;
49       tensor_0->format = schema::Format_NHWC;
50       tensor_0->dataType = TypeId::kNumberTypeFloat32;
51       tensor_0->dims = {1};
52       auto tensor_1 = std::make_unique<schema::TensorT>();
53       tensor_1->nodeType = lite::NodeType_Parameter;
54       tensor_1->format = schema::Format_NHWC;
55       tensor_1->dataType = TypeId::kNumberTypeFloat32;
56       tensor_1->dims = {1};
57       tensor_1->data.resize(sizeof(float));
58       auto data1 = reinterpret_cast<float *>(tensor_1->data.data());
59       ASSERT_NE(data1, nullptr);
60       data1[0] = 1;
61       auto tensor_2 = std::make_unique<schema::TensorT>();
62       tensor_2->nodeType = lite::NodeType_Parameter;
63       tensor_2->format = schema::Format_NHWC;
64       tensor_2->dataType = TypeId::kNumberTypeFloat32;
65       meta_graph->nodes.emplace_back(std::move(add_0));
66       meta_graph->allTensors[0] = std::move(tensor_0);
67       meta_graph->allTensors[1] = std::move(tensor_1);
68       meta_graph->allTensors[2] = std::move(tensor_2);
69     }
70     {  // add-1
71       auto add_1 = std::make_unique<schema::CNodeT>();
72       add_1->inputIndex = {2, 3};
73       add_1->outputIndex = {4};
74       add_1->primitive = std::make_unique<schema::PrimitiveT>();
75       add_1->primitive->value.type = schema::PrimitiveType_AddFusion;
76       auto add_1_prim = new schema::AddFusionT;
77       add_1_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
78       add_1->primitive->value.value = add_1_prim;
79       add_1->name = "Add1";
80       auto tensor_3 = std::make_unique<schema::TensorT>();
81       tensor_3->nodeType = lite::NodeType_ValueNode;
82       tensor_3->format = schema::Format_NHWC;
83       tensor_3->dataType = TypeId::kNumberTypeFloat32;
84       tensor_3->dims = {1};
85       tensor_3->data.resize(sizeof(float));
86       auto data3 = reinterpret_cast<float *>(tensor_3->data.data());
87       ASSERT_NE(data3, nullptr);
88       data3[0] = 1;
89       auto tensor_4 = std::make_unique<schema::TensorT>();
90       tensor_4->nodeType = lite::NodeType_Parameter;
91       tensor_4->format = schema::Format_NHWC;
92       tensor_4->dataType = TypeId::kNumberTypeFloat32;
93       meta_graph->nodes.emplace_back(std::move(add_1));
94       meta_graph->allTensors[3] = std::move(tensor_3);
95       meta_graph->allTensors[4] = std::move(tensor_4);
96     }
97     {  // partial cond
98       auto partial_cond = std::make_unique<schema::CNodeT>();
99       partial_cond->inputIndex = {4};
100       partial_cond->outputIndex = {9};
101       partial_cond->primitive = std::make_unique<schema::PrimitiveT>();
102       partial_cond->primitive->value.type = schema::PrimitiveType_PartialFusion;
103       auto partial_cond_prim = new schema::PartialFusionT;
104       partial_cond_prim->sub_graph_index = 1;
105       partial_cond->primitive->value.value = partial_cond_prim;
106       partial_cond->name = "partial_cond";
107       meta_graph->nodes.emplace_back(std::move(partial_cond));
108     }
109     {  // add-5
110       auto add_5 = std::make_unique<schema::CNodeT>();
111       add_5->inputIndex = {9, 13};
112       add_5->outputIndex = {14};
113       add_5->primitive = std::make_unique<schema::PrimitiveT>();
114       add_5->primitive->value.type = schema::PrimitiveType_AddFusion;
115       auto add_5_prim = new schema::AddFusionT;
116       add_5_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
117       add_5->primitive->value.value = add_5_prim;
118       add_5->name = "Add5";
119       auto tensor_13 = std::make_unique<schema::TensorT>();
120       tensor_13->nodeType = lite::NodeType_ValueNode;
121       tensor_13->format = schema::Format_NHWC;
122       tensor_13->dataType = TypeId::kNumberTypeFloat32;
123       tensor_13->dims = {1};
124       tensor_13->data.resize(sizeof(float));
125       auto data13 = reinterpret_cast<float *>(tensor_13->data.data());
126       ASSERT_NE(data13, nullptr);
127       data13[0] = 1;
128       auto tensor_14 = std::make_unique<schema::TensorT>();
129       tensor_14->nodeType = lite::NodeType_Parameter;
130       tensor_14->format = schema::Format_NHWC;
131       tensor_14->dataType = TypeId::kNumberTypeFloat32;
132       meta_graph->nodes.emplace_back(std::move(add_5));
133       meta_graph->allTensors[13] = std::move(tensor_13);
134       meta_graph->allTensors[14] = std::move(tensor_14);
135     }
136     auto sub_graph_0 = std::make_unique<schema::SubGraphT>();
137     sub_graph_0->name = "main_graph";
138     sub_graph_0->inputIndices = {0};
139     sub_graph_0->outputIndices = {14};
140     sub_graph_0->nodeIndices = {0, 1, 2, 3};
141     sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 9, 13, 14};
142     meta_graph->subGraph.emplace_back(std::move(sub_graph_0));
143   }
144   {    // subgraph-1
145     {  // add-2
146       auto add_2 = std::make_unique<schema::CNodeT>();
147       add_2->inputIndex = {4, 5};
148       add_2->outputIndex = {6};
149       add_2->primitive = std::make_unique<schema::PrimitiveT>();
150       add_2->primitive->value.type = schema::PrimitiveType_AddFusion;
151       auto add_2_prim = new schema::AddFusionT;
152       add_2_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
153       add_2->primitive->value.value = add_2_prim;
154       add_2->name = "Add2";
155       auto tensor_5 = std::make_unique<schema::TensorT>();
156       tensor_5->nodeType = lite::NodeType_ValueNode;
157       tensor_5->format = schema::Format_NHWC;
158       tensor_5->dataType = TypeId::kNumberTypeFloat32;
159       tensor_5->dims = {1};
160       tensor_5->data.resize(sizeof(float));
161       auto data5 = reinterpret_cast<float *>(tensor_5->data.data());
162       ASSERT_NE(data5, nullptr);
163       data5[0] = 1;
164       auto tensor_6 = std::make_unique<schema::TensorT>();
165       tensor_6->nodeType = lite::NodeType_Parameter;
166       tensor_6->format = schema::Format_NHWC;
167       tensor_6->dataType = TypeId::kNumberTypeFloat32;
168       meta_graph->nodes.emplace_back(std::move(add_2));
169       meta_graph->allTensors[5] = std::move(tensor_5);
170       meta_graph->allTensors[6] = std::move(tensor_6);
171     }
172     {  // less
173       auto less = std::make_unique<schema::CNodeT>();
174       less->inputIndex = {6, 15};
175       less->outputIndex = {7};
176       less->primitive = std::make_unique<schema::PrimitiveT>();
177       less->primitive->value.type = schema::PrimitiveType_Less;
178       auto less_prim = new schema::LessT;
179       less->primitive->value.value = less_prim;
180       less->name = "less";
181       auto tensor_15 = std::make_unique<schema::TensorT>();
182       tensor_15->nodeType = lite::NodeType_ValueNode;
183       tensor_15->format = schema::Format_NHWC;
184       tensor_15->dataType = TypeId::kNumberTypeFloat32;
185       tensor_15->dims = {1};
186       tensor_15->data.resize(sizeof(float));
187       auto data15 = reinterpret_cast<float *>(tensor_15->data.data());
188       ASSERT_NE(data15, nullptr);
189       data15[0] = 1;
190       auto tensor_7 = std::make_unique<schema::TensorT>();
191       tensor_7->nodeType = lite::NodeType_Parameter;
192       tensor_7->format = schema::Format_NHWC;
193       tensor_7->dataType = TypeId::kNumberTypeFloat32;
194       meta_graph->nodes.emplace_back(std::move(less));
195       meta_graph->allTensors[7] = std::move(tensor_7);
196       meta_graph->allTensors[15] = std::move(tensor_15);
197     }
198     {  // switch
199       auto switchop = std::make_unique<schema::CNodeT>();
200       switchop->inputIndex = {7, 4};
201       switchop->outputIndex = {8, 9};
202       switchop->primitive = std::make_unique<schema::PrimitiveT>();
203       switchop->primitive->value.type = schema::PrimitiveType_Switch;
204       auto switch_prim = new schema::SwitchT;
205       switchop->primitive->value.value = switch_prim;
206       switchop->name = "switch";
207       auto tensor_8 = std::make_unique<schema::TensorT>();
208       tensor_8->nodeType = lite::NodeType_Parameter;
209       tensor_8->format = schema::Format_NHWC;
210       tensor_8->dataType = TypeId::kNumberTypeFloat32;
211       auto tensor_9 = std::make_unique<schema::TensorT>();
212       tensor_9->nodeType = lite::NodeType_Parameter;
213       tensor_9->format = schema::Format_NHWC;
214       tensor_9->dataType = TypeId::kNumberTypeFloat32;
215       meta_graph->nodes.emplace_back(std::move(switchop));
216       meta_graph->allTensors[8] = std::move(tensor_8);
217       meta_graph->allTensors[9] = std::move(tensor_9);
218     }
219     {  // partial body
220       auto partial_body = std::make_unique<schema::CNodeT>();
221       partial_body->inputIndex = {8};
222       partial_body->outputIndex = {4};
223       partial_body->primitive = std::make_unique<schema::PrimitiveT>();
224       partial_body->primitive->value.type = schema::PrimitiveType_PartialFusion;
225       auto partial_body_prim = new schema::PartialFusionT;
226       partial_body_prim->sub_graph_index = 2;
227       partial_body->primitive->value.value = partial_body_prim;
228       partial_body->name = "partial_body";
229       meta_graph->nodes.emplace_back(std::move(partial_body));
230     }
231     auto sub_graph_1 = std::make_unique<schema::SubGraphT>();
232     sub_graph_1->name = "while_cond";
233     sub_graph_1->inputIndices = {4};
234     sub_graph_1->outputIndices = {9};
235     sub_graph_1->nodeIndices = {4, 5, 6, 7};
236     sub_graph_1->tensorIndices = {4, 5, 6, 7, 8, 9, 15};
237     meta_graph->subGraph.emplace_back(std::move(sub_graph_1));
238   }
239   {    // subgraph-2
240     {  // add-3
241       auto add_3 = std::make_unique<schema::CNodeT>();
242       add_3->inputIndex = {8, 10};
243       add_3->outputIndex = {11};
244       add_3->primitive = std::make_unique<schema::PrimitiveT>();
245       add_3->primitive->value.type = schema::PrimitiveType_AddFusion;
246       auto add_3_prim = new schema::AddFusionT;
247       add_3_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
248       add_3->primitive->value.value = add_3_prim;
249       add_3->name = "Add3";
250       auto tensor_10 = std::make_unique<schema::TensorT>();
251       tensor_10->nodeType = lite::NodeType_ValueNode;
252       tensor_10->format = schema::Format_NHWC;
253       tensor_10->dataType = TypeId::kNumberTypeFloat32;
254       tensor_10->dims = {1};
255       tensor_10->data.resize(sizeof(float));
256       auto data10 = reinterpret_cast<float *>(tensor_10->data.data());
257       ASSERT_NE(data10, nullptr);
258       data10[0] = 1;
259       auto tensor_11 = std::make_unique<schema::TensorT>();
260       tensor_11->nodeType = lite::NodeType_Parameter;
261       tensor_11->format = schema::Format_NHWC;
262       tensor_11->dataType = TypeId::kNumberTypeFloat32;
263       meta_graph->nodes.emplace_back(std::move(add_3));
264       meta_graph->allTensors[10] = std::move(tensor_10);
265       meta_graph->allTensors[11] = std::move(tensor_11);
266     }
267     {  // add-4
268       auto add_4 = std::make_unique<schema::CNodeT>();
269       add_4->inputIndex = {11, 12};
270       add_4->outputIndex = {4};
271       add_4->primitive = std::make_unique<schema::PrimitiveT>();
272       add_4->primitive->value.type = schema::PrimitiveType_AddFusion;
273       auto add_4_prim = new schema::AddFusionT;
274       add_4_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
275       add_4->primitive->value.value = add_4_prim;
276       add_4->name = "Add4";
277       auto tensor_12 = std::make_unique<schema::TensorT>();
278       tensor_12->nodeType = lite::NodeType_ValueNode;
279       tensor_12->format = schema::Format_NHWC;
280       tensor_12->dataType = TypeId::kNumberTypeFloat32;
281       tensor_12->dims = {1};
282       tensor_12->data.resize(sizeof(float));
283       auto data12 = reinterpret_cast<float *>(tensor_12->data.data());
284       ASSERT_NE(data12, nullptr);
285       data12[0] = 1;
286       meta_graph->nodes.emplace_back(std::move(add_4));
287       meta_graph->allTensors[12] = std::move(tensor_12);
288     }
289     {  // partial cond
290       auto partial_cond = std::make_unique<schema::CNodeT>();
291       partial_cond->inputIndex = {4};
292       partial_cond->outputIndex = {9};
293       partial_cond->primitive = std::make_unique<schema::PrimitiveT>();
294       partial_cond->primitive->value.type = schema::PrimitiveType_PartialFusion;
295       auto partial_cond_prim = new schema::PartialFusionT;
296       partial_cond_prim->sub_graph_index = 1;
297       partial_cond->primitive->value.value = partial_cond_prim;
298       partial_cond->name = "partial_cond1";
299       meta_graph->nodes.emplace_back(std::move(partial_cond));
300     }
301     auto sub_graph_2 = std::make_unique<schema::SubGraphT>();
302     sub_graph_2->name = "while_body";
303     sub_graph_2->inputIndices = {8};
304     sub_graph_2->outputIndices = {9};
305     sub_graph_2->nodeIndices = {8, 9, 10};
306     sub_graph_2->tensorIndices = {8, 10, 11, 12, 4, 9};
307     meta_graph->subGraph.emplace_back(std::move(sub_graph_2));
308   }
309   meta_graph->name = "graph";
310   meta_graph->version = Version();
311   //  -----------------------------------------------------------------------
312   lite::MetaGraphSerializer::Save(
313     *meta_graph, "/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph");
314   //  -----------------------------------------------------------------------
315   size_t size = 0;
316   char *graph_buf = lite::ReadFile(
317     "/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph.ms", &size);
318   ASSERT_NE(graph_buf, nullptr);
319 
320   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size));
321   ASSERT_NE(model, nullptr);
322   delete[](graph_buf);
323   auto context = std::make_shared<mindspore::lite::InnerContext>();
324   auto &cpu_device_ctx = context->device_list_[0];
325   cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
326   context->thread_num_ = 2;
327   auto session = std::shared_ptr<lite::LiteSession>(lite::LiteSession::CreateSession(context));
328   ASSERT_NE(session, nullptr);
329   auto ret = session->CompileGraph(model.get());
330   ASSERT_EQ(ret, lite::RET_OK);
331   auto inputs = session->GetInputs();
332   for (auto *input : inputs) {
333     (void)input->MutableData();
334   }
335   ret = session->RunGraph();
336   ASSERT_EQ(ret, lite::RET_OK);
337 }
338 }  // namespace mindspore
339