• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "mindspore/lite/include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/model.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "tools/common/storage.h"
#include "include/version.h"
30 
31 namespace mindspore {
32 class SubGraphTest : public mindspore::CommonTest {
33  public:
34   SubGraphTest() = default;
35 };
36 
TEST_F(SubGraphTest,RecursiveSubGraphTest)37 TEST_F(SubGraphTest, RecursiveSubGraphTest) {
38   auto meta_graph = std::make_shared<schema::MetaGraphT>();
39   meta_graph->allTensors.resize(16);
40   {    // subgraph-0
41     {  // add-0
42       auto add_0 = std::make_unique<schema::CNodeT>();
43       add_0->inputIndex = {0, 1};
44       add_0->outputIndex = {2};
45       add_0->primitive = std::make_unique<schema::PrimitiveT>();
46       add_0->primitive->value.type = schema::PrimitiveType_AddFusion;
47       auto add_0_prim = new schema::AddFusionT;
48       add_0_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
49       add_0->primitive->value.value = add_0_prim;
50       add_0->name = "Add0";
51       auto tensor_0 = std::make_unique<schema::TensorT>();
52       tensor_0->nodeType = lite::NodeType_ValueNode;
53       tensor_0->format = schema::Format_NHWC;
54       tensor_0->dataType = TypeId::kNumberTypeFloat32;
55       tensor_0->dims = {1};
56       auto tensor_1 = std::make_unique<schema::TensorT>();
57       tensor_1->nodeType = lite::NodeType_ValueNode;
58       tensor_1->format = schema::Format_NHWC;
59       tensor_1->dataType = TypeId::kNumberTypeFloat32;
60       tensor_1->dims = {1};
61       tensor_1->data.resize(sizeof(float));
62       auto data1 = reinterpret_cast<float *>(tensor_1->data.data());
63       ASSERT_NE(data1, nullptr);
64       data1[0] = 1;
65       auto tensor_2 = std::make_unique<schema::TensorT>();
66       tensor_2->nodeType = lite::NodeType_Parameter;
67       tensor_2->format = schema::Format_NHWC;
68       tensor_2->dataType = TypeId::kNumberTypeFloat32;
69       meta_graph->nodes.emplace_back(std::move(add_0));
70       meta_graph->allTensors[0] = std::move(tensor_0);
71       meta_graph->allTensors[1] = std::move(tensor_1);
72       meta_graph->allTensors[2] = std::move(tensor_2);
73     }
74     {  // add-1
75       auto add_1 = std::make_unique<schema::CNodeT>();
76       add_1->inputIndex = {2, 3};
77       add_1->outputIndex = {4};
78       add_1->primitive = std::make_unique<schema::PrimitiveT>();
79       add_1->primitive->value.type = schema::PrimitiveType_AddFusion;
80       auto add_1_prim = new schema::AddFusionT;
81       add_1_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
82       add_1->primitive->value.value = add_1_prim;
83       add_1->name = "Add1";
84       auto tensor_3 = std::make_unique<schema::TensorT>();
85       tensor_3->nodeType = lite::NodeType_ValueNode;
86       tensor_3->format = schema::Format_NHWC;
87       tensor_3->dataType = TypeId::kNumberTypeFloat32;
88       tensor_3->dims = {1};
89       tensor_3->data.resize(sizeof(float));
90       auto data3 = reinterpret_cast<float *>(tensor_3->data.data());
91       ASSERT_NE(data3, nullptr);
92       data3[0] = 1;
93       auto tensor_4 = std::make_unique<schema::TensorT>();
94       tensor_4->nodeType = lite::NodeType_Parameter;
95       tensor_4->format = schema::Format_NHWC;
96       tensor_4->dataType = TypeId::kNumberTypeFloat32;
97       meta_graph->nodes.emplace_back(std::move(add_1));
98       meta_graph->allTensors[3] = std::move(tensor_3);
99       meta_graph->allTensors[4] = std::move(tensor_4);
100     }
101     {  // partial cond
102       auto partial_cond = std::make_unique<schema::CNodeT>();
103       partial_cond->inputIndex = {4};
104       partial_cond->outputIndex = {9};
105       partial_cond->primitive = std::make_unique<schema::PrimitiveT>();
106       partial_cond->primitive->value.type = schema::PrimitiveType_PartialFusion;
107       auto partial_cond_prim = new schema::PartialFusionT;
108       partial_cond_prim->sub_graph_index = 1;
109       partial_cond->primitive->value.value = partial_cond_prim;
110       partial_cond->name = "partial_cond";
111       meta_graph->nodes.emplace_back(std::move(partial_cond));
112     }
113     {  // add-5
114       auto add_5 = std::make_unique<schema::CNodeT>();
115       add_5->inputIndex = {9, 13};
116       add_5->outputIndex = {14};
117       add_5->primitive = std::make_unique<schema::PrimitiveT>();
118       add_5->primitive->value.type = schema::PrimitiveType_AddFusion;
119       auto add_5_prim = new schema::AddFusionT;
120       add_5_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
121       add_5->primitive->value.value = add_5_prim;
122       add_5->name = "Add5";
123       auto tensor_13 = std::make_unique<schema::TensorT>();
124       tensor_13->nodeType = lite::NodeType_ValueNode;
125       tensor_13->format = schema::Format_NHWC;
126       tensor_13->dataType = TypeId::kNumberTypeFloat32;
127       tensor_13->dims = {1};
128       tensor_13->data.resize(sizeof(float));
129       auto data13 = reinterpret_cast<float *>(tensor_13->data.data());
130       ASSERT_NE(data13, nullptr);
131       data13[0] = 1;
132       auto tensor_14 = std::make_unique<schema::TensorT>();
133       tensor_14->nodeType = lite::NodeType_Parameter;
134       tensor_14->format = schema::Format_NHWC;
135       tensor_14->dataType = TypeId::kNumberTypeFloat32;
136       meta_graph->nodes.emplace_back(std::move(add_5));
137       meta_graph->allTensors[13] = std::move(tensor_13);
138       meta_graph->allTensors[14] = std::move(tensor_14);
139     }
140     auto sub_graph_0 = std::make_unique<schema::SubGraphT>();
141     sub_graph_0->name = "main_graph";
142     sub_graph_0->inputIndices = {0};
143     sub_graph_0->outputIndices = {14};
144     sub_graph_0->nodeIndices = {0, 1, 2, 3};
145     sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 9, 13, 14};
146     meta_graph->subGraph.emplace_back(std::move(sub_graph_0));
147   }
148   {    // subgraph-1
149     {  // add-2
150       auto add_2 = std::make_unique<schema::CNodeT>();
151       add_2->inputIndex = {4, 5};
152       add_2->outputIndex = {6};
153       add_2->primitive = std::make_unique<schema::PrimitiveT>();
154       add_2->primitive->value.type = schema::PrimitiveType_AddFusion;
155       auto add_2_prim = new schema::AddFusionT;
156       add_2_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
157       add_2->primitive->value.value = add_2_prim;
158       add_2->name = "Add2";
159       auto tensor_5 = std::make_unique<schema::TensorT>();
160       tensor_5->nodeType = lite::NodeType_ValueNode;
161       tensor_5->format = schema::Format_NHWC;
162       tensor_5->dataType = TypeId::kNumberTypeFloat32;
163       tensor_5->dims = {1};
164       tensor_5->data.resize(sizeof(float));
165       auto data5 = reinterpret_cast<float *>(tensor_5->data.data());
166       ASSERT_NE(data5, nullptr);
167       data5[0] = 1;
168       auto tensor_6 = std::make_unique<schema::TensorT>();
169       tensor_6->nodeType = lite::NodeType_Parameter;
170       tensor_6->format = schema::Format_NHWC;
171       tensor_6->dataType = TypeId::kNumberTypeFloat32;
172       meta_graph->nodes.emplace_back(std::move(add_2));
173       meta_graph->allTensors[5] = std::move(tensor_5);
174       meta_graph->allTensors[6] = std::move(tensor_6);
175     }
176     {  // less
177       auto less = std::make_unique<schema::CNodeT>();
178       less->inputIndex = {6, 15};
179       less->outputIndex = {7};
180       less->primitive = std::make_unique<schema::PrimitiveT>();
181       less->primitive->value.type = schema::PrimitiveType_Less;
182       auto less_prim = new schema::LessT;
183       less->primitive->value.value = less_prim;
184       less->name = "less";
185       auto tensor_15 = std::make_unique<schema::TensorT>();
186       tensor_15->nodeType = lite::NodeType_ValueNode;
187       tensor_15->format = schema::Format_NHWC;
188       tensor_15->dataType = TypeId::kNumberTypeFloat32;
189       tensor_15->dims = {1};
190       tensor_15->data.resize(sizeof(float));
191       auto data15 = reinterpret_cast<float *>(tensor_15->data.data());
192       ASSERT_NE(data15, nullptr);
193       data15[0] = 1;
194       auto tensor_7 = std::make_unique<schema::TensorT>();
195       tensor_7->nodeType = lite::NodeType_Parameter;
196       tensor_7->format = schema::Format_NHWC;
197       tensor_7->dataType = TypeId::kNumberTypeFloat32;
198       meta_graph->nodes.emplace_back(std::move(less));
199       meta_graph->allTensors[7] = std::move(tensor_7);
200       meta_graph->allTensors[15] = std::move(tensor_15);
201     }
202     {  // switch
203       auto switchop = std::make_unique<schema::CNodeT>();
204       switchop->inputIndex = {7, 4};
205       switchop->outputIndex = {8, 9};
206       switchop->primitive = std::make_unique<schema::PrimitiveT>();
207       switchop->primitive->value.type = schema::PrimitiveType_Switch;
208       auto switch_prim = new schema::SwitchT;
209       switchop->primitive->value.value = switch_prim;
210       switchop->name = "switch";
211       auto tensor_8 = std::make_unique<schema::TensorT>();
212       tensor_8->nodeType = lite::NodeType_Parameter;
213       tensor_8->format = schema::Format_NHWC;
214       tensor_8->dataType = TypeId::kNumberTypeFloat32;
215       auto tensor_9 = std::make_unique<schema::TensorT>();
216       tensor_9->nodeType = lite::NodeType_Parameter;
217       tensor_9->format = schema::Format_NHWC;
218       tensor_9->dataType = TypeId::kNumberTypeFloat32;
219       meta_graph->nodes.emplace_back(std::move(switchop));
220       meta_graph->allTensors[8] = std::move(tensor_8);
221       meta_graph->allTensors[9] = std::move(tensor_9);
222     }
223     {  // partial body
224       auto partial_body = std::make_unique<schema::CNodeT>();
225       partial_body->inputIndex = {8};
226       partial_body->outputIndex = {4};
227       partial_body->primitive = std::make_unique<schema::PrimitiveT>();
228       partial_body->primitive->value.type = schema::PrimitiveType_PartialFusion;
229       auto partial_body_prim = new schema::PartialFusionT;
230       partial_body_prim->sub_graph_index = 2;
231       partial_body->primitive->value.value = partial_body_prim;
232       partial_body->name = "partial_body";
233       meta_graph->nodes.emplace_back(std::move(partial_body));
234     }
235     auto sub_graph_1 = std::make_unique<schema::SubGraphT>();
236     sub_graph_1->name = "while_cond";
237     sub_graph_1->inputIndices = {4};
238     sub_graph_1->outputIndices = {9};
239     sub_graph_1->nodeIndices = {4, 5, 6, 7};
240     sub_graph_1->tensorIndices = {4, 5, 6, 7, 8, 9, 15};
241     meta_graph->subGraph.emplace_back(std::move(sub_graph_1));
242   }
243   {    // subgraph-2
244     {  // add-3
245       auto add_3 = std::make_unique<schema::CNodeT>();
246       add_3->inputIndex = {8, 10};
247       add_3->outputIndex = {11};
248       add_3->primitive = std::make_unique<schema::PrimitiveT>();
249       add_3->primitive->value.type = schema::PrimitiveType_AddFusion;
250       auto add_3_prim = new schema::AddFusionT;
251       add_3_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
252       add_3->primitive->value.value = add_3_prim;
253       add_3->name = "Add3";
254       auto tensor_10 = std::make_unique<schema::TensorT>();
255       tensor_10->nodeType = lite::NodeType_ValueNode;
256       tensor_10->format = schema::Format_NHWC;
257       tensor_10->dataType = TypeId::kNumberTypeFloat32;
258       tensor_10->dims = {1};
259       tensor_10->data.resize(sizeof(float));
260       auto data10 = reinterpret_cast<float *>(tensor_10->data.data());
261       ASSERT_NE(data10, nullptr);
262       data10[0] = 1;
263       auto tensor_11 = std::make_unique<schema::TensorT>();
264       tensor_11->nodeType = lite::NodeType_Parameter;
265       tensor_11->format = schema::Format_NHWC;
266       tensor_11->dataType = TypeId::kNumberTypeFloat32;
267       meta_graph->nodes.emplace_back(std::move(add_3));
268       meta_graph->allTensors[10] = std::move(tensor_10);
269       meta_graph->allTensors[11] = std::move(tensor_11);
270     }
271     {  // add-4
272       auto add_4 = std::make_unique<schema::CNodeT>();
273       add_4->inputIndex = {11, 12};
274       add_4->outputIndex = {4};
275       add_4->primitive = std::make_unique<schema::PrimitiveT>();
276       add_4->primitive->value.type = schema::PrimitiveType_AddFusion;
277       auto add_4_prim = new schema::AddFusionT;
278       add_4_prim->activation_type = schema::ActivationType_NO_ACTIVATION;
279       add_4->primitive->value.value = add_4_prim;
280       add_4->name = "Add4";
281       auto tensor_12 = std::make_unique<schema::TensorT>();
282       tensor_12->nodeType = lite::NodeType_ValueNode;
283       tensor_12->format = schema::Format_NHWC;
284       tensor_12->dataType = TypeId::kNumberTypeFloat32;
285       tensor_12->dims = {1};
286       tensor_12->data.resize(sizeof(float));
287       auto data12 = reinterpret_cast<float *>(tensor_12->data.data());
288       ASSERT_NE(data12, nullptr);
289       data12[0] = 1;
290       meta_graph->nodes.emplace_back(std::move(add_4));
291       meta_graph->allTensors[12] = std::move(tensor_12);
292     }
293     {  // partial cond
294       auto partial_cond = std::make_unique<schema::CNodeT>();
295       partial_cond->inputIndex = {4};
296       partial_cond->outputIndex = {9};
297       partial_cond->primitive = std::make_unique<schema::PrimitiveT>();
298       partial_cond->primitive->value.type = schema::PrimitiveType_PartialFusion;
299       auto partial_cond_prim = new schema::PartialFusionT;
300       partial_cond_prim->sub_graph_index = 1;
301       partial_cond->primitive->value.value = partial_cond_prim;
302       partial_cond->name = "partial_cond1";
303       meta_graph->nodes.emplace_back(std::move(partial_cond));
304     }
305     auto sub_graph_2 = std::make_unique<schema::SubGraphT>();
306     sub_graph_2->name = "while_body";
307     sub_graph_2->inputIndices = {8};
308     sub_graph_2->outputIndices = {9};
309     sub_graph_2->nodeIndices = {8, 9, 10};
310     sub_graph_2->tensorIndices = {8, 10, 11, 12, 4, 9};
311     meta_graph->subGraph.emplace_back(std::move(sub_graph_2));
312   }
313   meta_graph->name = "graph";
314   meta_graph->version = lite::Version();
315   //  -----------------------------------------------------------------------
316   lite::Storage::Save(*meta_graph,
317                       "/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph");
318   //  -----------------------------------------------------------------------
319   size_t size = 0;
320   char *graph_buf = lite::ReadFile(
321     "/mnt/data/workspace/OpenAI/Huawei/mindspore/mindspore/lite/my_test/models/recursive_subgraph.ms", &size);
322   ASSERT_NE(graph_buf, nullptr);
323 
324   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size));
325   ASSERT_NE(model, nullptr);
326   delete[](graph_buf);
327   lite::Context context;
328   auto &cpu_device_ctx = context.device_list_[0];
329   cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
330   context.thread_num_ = 2;
331   auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&context));
332   ASSERT_NE(session, nullptr);
333   auto ret = session->CompileGraph(model.get());
334   ASSERT_EQ(ret, lite::RET_OK);
335   auto inputs = session->GetInputs();
336   for (auto *input : inputs) {
337     (void)input->MutableData();
338   }
339   ret = session->RunGraph();
340   ASSERT_EQ(ret, lite::RET_OK);
341 }
342 }  // namespace mindspore
343