1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "common/common_test.h"
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "mindspore/core/ops/nn_ops.h"
20 #include "mindspore/core/ops/math_ops.h"
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "ir/param_info.h"
23 #include "frontend/operator/ops.h"
24 #include "include/backend/kernel_graph.h"
25 #include "include/backend/anf_runtime_algorithm.h"
26 #include "mindspore/ccsrc/include/backend/kernel_info.h"
27 #include "mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_device_address.h"
28 #include "include/common/utils/utils.h"
29 #include "include/common/utils/anfalgo.h"
30
31 namespace mindspore {
32 namespace session {
33 namespace {
34 constexpr auto kPatternConvolution = "Convolution";
35 }
36
37 using device::KernelInfo;
38 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
39 using AscendDeviceAddress = device::ascend::AscendDeviceAddress;
40
41 class AnfRuntimeAlgorithmTest : public UT::Common {
42 public:
43 AnfRuntimeAlgorithmTest() = default;
SetUp()44 void SetUp() override {}
TearDown()45 void TearDown() override {}
46 };
47
// Checks common::AnfAlgo::VisitKernel.
// Value nodes, parameters and plain CNodes resolve to themselves at index 0. MakeTuple,
// TupleGetItem and Depend are looked through to the node producing the requested output.
// A null node throws.
TEST_F(AnfRuntimeAlgorithmTest, VisitKernel) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  KernelWithIndex kernel_with_index;
  // test nullptr as input
  EXPECT_THROW(common::AnfAlgo::VisitKernel(nullptr, 0), std::runtime_error);
  // test value node as input
  ValueNodePtr value_node = NewValueNode(prim::kPrimAdd);
  kernel_with_index = common::AnfAlgo::VisitKernel(value_node, 0);
  // ASSERT (fatal), not EXPECT: a null result must end the test, not be dereferenced below.
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<ValueNodePtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<ValueNodePtr>()).get(), value_node.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
  // test parameter node as input
  ParameterPtr parameter_node = kernel_graph->add_parameter();
  kernel_with_index = common::AnfAlgo::VisitKernel(parameter_node, 0);
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<ParameterPtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<ParameterPtr>()).get(), parameter_node.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
  // test cnode as input
  std::vector<AnfNodePtr> inputs{value_node};
  auto add = kernel_graph->NewCNode(inputs);
  kernel_with_index = common::AnfAlgo::VisitKernel(add, 0);
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
  // test maketuple node as input: each tuple index maps to the corresponding element node
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_second = kernel_graph->NewCNode(add_inputs);
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), add, add_second};
  auto make_tuple = kernel_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  std::vector<int64_t> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(args_spec_list));
  kernel_with_index = common::AnfAlgo::VisitKernel(make_tuple, 0);
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
  kernel_with_index = common::AnfAlgo::VisitKernel(make_tuple, 1);
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add_second.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
  // test tuple get item node as input: item 1 of the make_tuple is add_second
  std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), make_tuple,
                                                NewValueNode(static_cast<int64_t>(1))};
  auto tuple_get_item = kernel_graph->NewCNode(tuple_get_item_inputs);
  kernel_with_index = common::AnfAlgo::VisitKernel(tuple_get_item, 0);
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add_second.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
  // test depend node as input: expected to resolve to its first real input (add)
  std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), add, add_second};
  auto depend = kernel_graph->NewCNode(depend_inputs);
  kernel_with_index = common::AnfAlgo::VisitKernel(depend, 0);
  ASSERT_NE(kernel_with_index.first, nullptr);
  EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr);
  EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add.get());
  EXPECT_EQ(kernel_with_index.second, 0U);
}
106
// GetCNodePrimitive must hand back the very Primitive object used to build the CNode,
// and must reject a null node with an exception.
TEST_F(AnfRuntimeAlgorithmTest, GetCNodePrimitive) {
  auto graph = std::make_shared<KernelGraph>();

  // Build Add(<no operands>) from a known primitive instance.
  PrimitivePtr expected_prim = prim::kPrimAdd;
  std::vector<AnfNodePtr> node_inputs{NewValueNode(expected_prim)};
  auto add_node = graph->NewCNode(node_inputs);

  EXPECT_NE(common::AnfAlgo::GetCNodePrimitive(add_node), nullptr);
  EXPECT_EQ(common::AnfAlgo::GetCNodePrimitive(add_node).get(), expected_prim.get());

  // Invalid input.
  EXPECT_THROW(common::AnfAlgo::GetCNodePrimitive(nullptr), std::runtime_error);
}
117
// GetCNodeName reports the name of the primitive a CNode was built from.
// A null node and a Parameter node are both expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetCNodeName) {
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);
  EXPECT_EQ(common::AnfAlgo::GetCNodeName(add_node), prim::kPrimAdd->name());

  // Rejected inputs: null, then a Parameter (not a CNode).
  EXPECT_THROW(common::AnfAlgo::GetCNodeName(nullptr), std::runtime_error);
  auto param = graph->add_parameter();
  EXPECT_THROW(common::AnfAlgo::GetCNodeName(param), std::runtime_error);
}
129
// GetNodeDebugString must agree with the node's own DebugString() and throw for a null node.
TEST_F(AnfRuntimeAlgorithmTest, GetNodeDebugString) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);

  const auto via_helper = common::AnfAlgo::GetNodeDebugString(add_node);
  EXPECT_EQ(via_helper, add_node->DebugString());

  EXPECT_THROW(common::AnfAlgo::GetNodeDebugString(nullptr), std::runtime_error);
}
138
// SetNodeAttr on a CNode must make the attribute visible on the node's primitive.
// Setting an attribute on a Parameter node is expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, SetNodeAttr) {
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);
  common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value"), add_node);

  auto prim = common::AnfAlgo::GetCNodePrimitive(add_node);
  MS_EXCEPTION_IF_NULL(prim);
  const auto stored_value = GetValue<std::string>(prim->GetAttr("test_set_attr"));
  EXPECT_EQ(stored_value, "test_value");

  auto param = graph->add_parameter();
  EXPECT_THROW(common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value"), param), std::runtime_error);
}
152
// CopyNodeAttr(key, from, to) overwrites `to`'s attribute `key` with the value held by `from`.
// Any Parameter or null on either side is expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, CopyNodeAttr) {
  auto graph = std::make_shared<KernelGraph>();

  // Destination: Add carrying the old value.
  std::vector<AnfNodePtr> dst_inputs{NewValueNode(prim::kPrimAdd)};
  auto dst = graph->NewCNode(dst_inputs);
  common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value"), dst);

  // Source: Mul carrying the new value.
  std::vector<AnfNodePtr> src_inputs{NewValueNode(prim::kPrimMul)};
  auto src = graph->NewCNode(src_inputs);
  common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value_v2"), src);

  common::AnfAlgo::CopyNodeAttr("test_set_attr", src, dst);
  auto dst_prim = common::AnfAlgo::GetCNodePrimitive(dst);
  MS_EXCEPTION_IF_NULL(dst_prim);
  EXPECT_EQ(GetValue<std::string>(dst_prim->GetAttr("test_set_attr")), "test_value_v2");

  // Invalid argument combinations.
  auto param = graph->add_parameter();
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttr("test_set_attr", param, dst), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttr("test_set_attr", src, param), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttr("test_set_attr", param, param), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttr("test_set_attr", nullptr, dst), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttr("test_set_attr", src, nullptr), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttr("test_set_attr", nullptr, nullptr), std::runtime_error);
}
176
// CopyNodeAttrs(from, to) copies every attribute of `from` onto `to`; here the shared key
// "test_set_attr" must end up with the source's value. Parameters or nulls on either side throw.
TEST_F(AnfRuntimeAlgorithmTest, CopyNodeAttrs) {
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> dst_inputs{NewValueNode(prim::kPrimAdd)};
  auto dst = graph->NewCNode(dst_inputs);
  common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value"), dst);

  std::vector<AnfNodePtr> src_inputs{NewValueNode(prim::kPrimMul)};
  auto src = graph->NewCNode(src_inputs);
  common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value_v2"), src);

  common::AnfAlgo::CopyNodeAttrs(src, dst);
  auto dst_prim = common::AnfAlgo::GetCNodePrimitive(dst);
  MS_EXCEPTION_IF_NULL(dst_prim);
  EXPECT_EQ(GetValue<std::string>(dst_prim->GetAttr("test_set_attr")), "test_value_v2");

  // Invalid argument combinations.
  auto param = graph->add_parameter();
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttrs(param, dst), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttrs(src, param), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttrs(param, param), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttrs(nullptr, dst), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttrs(src, nullptr), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::CopyNodeAttrs(nullptr, nullptr), std::runtime_error);
}
200
// EraseNodeAttr removes exactly the named attribute from a CNode; reading it back afterwards
// throws, while other attributes on the same node remain readable. Null and Parameter nodes throw.
TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  // test cnode node
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add = kernel_graph->NewCNode(add_inputs);
  common::AnfAlgo::SetNodeAttr("test_set_attr", MakeValue("test_value"), add);
  common::AnfAlgo::SetNodeAttr("test_set_attr_v2", MakeValue("test_value_v2"), add);
  common::AnfAlgo::EraseNodeAttr("test_set_attr_v2", add);
  EXPECT_THROW(common::AnfAlgo::GetNodeAttr<std::string>(add, "test_set_attr_v2"), std::runtime_error);
  // The erase must be targeted: the sibling attribute set above has to survive it.
  EXPECT_EQ(common::AnfAlgo::GetNodeAttr<std::string>(add, "test_set_attr"), "test_value");
  EXPECT_THROW(common::AnfAlgo::EraseNodeAttr("test_set_attr_v2", nullptr), std::runtime_error);
  // test parameter node
  auto parameter = kernel_graph->add_parameter();
  EXPECT_THROW(common::AnfAlgo::EraseNodeAttr("test_set_attr_v2", parameter), std::runtime_error);
}
215
// GetInputTensorNum counts a CNode's real operands (the primitive in slot 0 is excluded):
// Add(p1, p2) -> 2. Null and Parameter nodes are expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) {
  auto graph = std::make_shared<KernelGraph>();
  auto lhs = graph->NewParameter();
  auto rhs = graph->NewParameter();
  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd), lhs, rhs};
  auto add_node = graph->NewCNode(node_inputs);

  EXPECT_EQ(common::AnfAlgo::GetInputTensorNum(add_node), 2);

  EXPECT_THROW(common::AnfAlgo::GetInputTensorNum(nullptr), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::GetInputTensorNum(lhs), std::runtime_error);
}
228
// GetOutputTensorNum:
//  - a BatchNorm node whose tuple abstract and selected build info both describe five outputs -> 5;
//  - an Add node with an AbstractNone output -> 0, with a single tensor output -> 1;
//  - a null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) {
  constexpr size_t kBnOutputNum = 5;
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> bn_inputs{NewValueNode(prim::kPrimBatchNorm)};
  auto bn = graph->NewCNode(bn_inputs);
  MS_EXCEPTION_IF_NULL(bn);
  std::vector<int64_t> dims{2, 32, 224, 224};
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, dims);
  AbstractBasePtrList bn_outputs(kBnOutputNum, tensor_abs);
  bn->set_abstract(std::make_shared<abstract::AbstractTuple>(bn_outputs));

  KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat(std::vector<std::string>(kBnOutputNum, kOpFormat_DEFAULT));
  builder.SetOutputsDeviceType(std::vector<TypeId>(kBnOutputNum, kFloat32->type_id()));
  builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TUPLE_UNFOLD});
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get());
  EXPECT_EQ(AnfAlgo::GetOutputTensorNum(bn), kBnOutputNum);
  EXPECT_THROW(AnfAlgo::GetOutputTensorNum(nullptr), std::runtime_error);

  // Add: the count follows whatever abstract the node currently carries.
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add_node);
  add_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  EXPECT_EQ(AnfAlgo::GetOutputTensorNum(add_node), 0);
  add_node->set_abstract(tensor_abs);
  EXPECT_EQ(AnfAlgo::GetOutputTensorNum(add_node), 1);
}
259
// GetOutputFormat returns the selected kernel build info's format for each output index.
// Out-of-range indices and null nodes throw.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimAdd), kernel_graph->NewParameter(),
                                    kernel_graph->NewParameter()};
  auto add = kernel_graph->NewCNode(inputs);
  // Validate the node before any AnfAlgo call operates on it.
  MS_EXCEPTION_IF_NULL(add);
  ShapeVector shape = {1, 2, 3, 4};
  // Two inferred outputs, matching the two device formats selected below.
  common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get());
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);
  KernelBuildInfoBuilder builder;
  builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat16->type_id()});
  builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NC1HWC0});
  d_kernel_info->set_select_kernel_build_info(builder.Build());
  EXPECT_EQ(AnfAlgo::GetOutputFormat(add, 0), kOpFormat_NCHW);
  EXPECT_EQ(AnfAlgo::GetOutputFormat(add, 1), kOpFormat_NC1HWC0);
  EXPECT_THROW(AnfAlgo::GetOutputFormat(add, 2), std::runtime_error);
  EXPECT_THROW(AnfAlgo::GetOutputFormat(nullptr, 0), std::runtime_error);
}
280
// GetInputFormat returns the selected kernel build info's format for each input index.
// Out-of-range indices and null nodes throw.
TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs = {NewValueNode(prim::kPrimAdd), graph->NewParameter(),
                                         graph->NewParameter()};
  auto add_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(add_node);

  add_node->set_kernel_info(std::make_shared<KernelInfo>());
  auto *info = dynamic_cast<KernelInfo *>(add_node->kernel_info());
  MS_EXCEPTION_IF_NULL(info);
  KernelBuildInfoBuilder input_builder;
  input_builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()});
  input_builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NC1HWC0});
  info->set_select_kernel_build_info(input_builder.Build());

  EXPECT_EQ(AnfAlgo::GetInputFormat(add_node, 0), kOpFormat_NCHW);
  EXPECT_EQ(AnfAlgo::GetInputFormat(add_node, 1), kOpFormat_NC1HWC0);
  EXPECT_THROW(AnfAlgo::GetInputFormat(add_node, 2), std::runtime_error);
  EXPECT_THROW(AnfAlgo::GetInputFormat(nullptr, 0), std::runtime_error);
}
299
// GetPrevNodeOutputFormat(node, i) reports the selected output format of the node feeding
// input i of `node`. Null and Parameter nodes are expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputFormat) {
  auto graph = std::make_shared<KernelGraph>();

  // Producer whose single output is selected as float32 / NCHW.
  std::vector<AnfNodePtr> producer_inputs{NewValueNode(prim::kPrimAdd)};
  auto producer = graph->NewCNode(producer_inputs);
  MS_EXCEPTION_IF_NULL(producer);
  producer->set_kernel_info(std::make_shared<KernelInfo>());
  auto *producer_info = dynamic_cast<KernelInfo *>(producer->kernel_info());
  MS_EXCEPTION_IF_NULL(producer_info);
  KernelBuildInfoBuilder producer_builder;
  producer_builder.SetOutputsDeviceType({kFloat32->type_id()});
  producer_builder.SetOutputsFormat({kOpFormat_NCHW});
  producer_info->set_select_kernel_build_info(producer_builder.Build());

  // Consumer whose input 0 is the producer.
  std::vector<AnfNodePtr> consumer_inputs{NewValueNode(prim::kPrimAdd), producer};
  auto consumer = graph->NewCNode(consumer_inputs);
  EXPECT_EQ(AnfAlgo::GetPrevNodeOutputFormat(consumer, 0), kOpFormat_NCHW);
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputFormat(nullptr, 0), std::runtime_error);

  auto param = graph->add_parameter();
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputFormat(param, 0), std::runtime_error);
}
321
// GetOutputInferShape reads shapes from a node's abstract:
//  - value nodes, parameters and CNodes carrying a tensor abstract expose its shape at index 0;
//  - AbstractNone yields an empty shape;
//  - tuple abstracts are indexed element by element;
//  - out-of-range indices and null nodes throw.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputInferShape) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<int64_t> dims{2, 32, 224, 224};
  auto none_abs = std::make_shared<abstract::AbstractNone>();
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, dims);
  AbstractBasePtrList tuple_elements{tensor_abs, none_abs, tensor_abs};
  auto tuple_abs = std::make_shared<abstract::AbstractTuple>(tuple_elements);

  // Value node.
  auto value_node = NewValueNode(prim::kPrimAdd);
  MS_EXCEPTION_IF_NULL(value_node);
  value_node->set_abstract(tensor_abs);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(value_node, 0)[1], 32);
  EXPECT_THROW(common::AnfAlgo::GetOutputInferShape(nullptr, 0), std::runtime_error);

  // Parameter.
  auto param = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(param);
  param->set_abstract(tensor_abs);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(param, 0)[2], 224);

  // CNode whose abstract is swapped between none, tensor and tuple.
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add_node);

  add_node->set_abstract(std::make_shared<abstract::AbstractNone>());
  EXPECT_TRUE(common::AnfAlgo::GetOutputInferShape(add_node, 0).empty());

  add_node->set_abstract(tensor_abs);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add_node, 0)[3], 224);
  EXPECT_THROW(common::AnfAlgo::GetOutputInferShape(add_node, 1), std::runtime_error);

  add_node->set_abstract(tuple_abs);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add_node, 0)[0], 2);
  EXPECT_TRUE(common::AnfAlgo::GetOutputInferShape(add_node, 1).empty());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add_node, 2)[1], 32);
  EXPECT_THROW(common::AnfAlgo::GetOutputInferShape(add_node, 3), std::runtime_error);
}
356
// GetPrevNodeOutputInferShape(node, i) returns the inferred shape of the node feeding input i.
// A Parameter (not a CNode) and an out-of-range input index are expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<int64_t> dims{2, 32, 224, 224};
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, dims);

  auto producer = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(producer);
  producer->set_abstract(tensor_abs);
  EXPECT_THROW(common::AnfAlgo::GetPrevNodeOutputInferShape(producer, 0), std::runtime_error);

  std::vector<AnfNodePtr> consumer_inputs{NewValueNode(prim::kPrimAdd), producer};
  auto consumer = graph->NewCNode(consumer_inputs);
  EXPECT_EQ(common::AnfAlgo::GetPrevNodeOutputInferShape(consumer, 0)[1], 32);
  EXPECT_THROW(common::AnfAlgo::GetPrevNodeOutputInferShape(consumer, 1), std::runtime_error);
}
372
// GetOutputDeviceShape converts each inferred output shape according to its selected format.
// The expected values are pinned below: NCHW keeps the layout, NHWC permutes
// {2,32,224,224} -> {2,224,224,32}, and FRAC_NZ maps {1,2,3,4} -> {1,2,1,1,16,16}.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<int64_t> dims{2, 32, 224, 224};
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, dims);
  AbstractBasePtrList outputs{tensor_abs, tensor_abs, tensor_abs};
  outputs.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{1, 2, 3, 4}));
  auto tuple_abs = std::make_shared<abstract::AbstractTuple>(outputs);

  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(add_node);
  add_node->set_abstract(tuple_abs);
  add_node->set_kernel_info(std::make_shared<KernelInfo>());
  auto *info = dynamic_cast<KernelInfo *>(add_node->kernel_info());
  MS_EXCEPTION_IF_NULL(info);

  KernelBuildInfoBuilder out_builder;
  out_builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ});
  out_builder.SetOutputsDeviceType(
    {kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
  info->set_select_kernel_build_info(out_builder.Build());

  EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add_node, 0)[2], 224);
  EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add_node, 1)[0], 2);
  const ShapeVector nhwc_shape{2, 224, 224, 32};
  EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add_node, 2), nhwc_shape);
  const ShapeVector frac_nz_shape{1, 2, 1, 1, 16, 16};
  EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add_node, 3), frac_nz_shape);
}
400
// GetInputDeviceShape converts each input's inferred shape according to the selected input
// format. NCHW keeps the layout; NHWC permutes {2,32,224,224} -> {2,224,224,32}.
// A null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  std::vector<int64_t> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  auto parameter_one = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_one);
  parameter_one->set_abstract(x_abstract);
  auto parameter_two = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_two);
  parameter_two->set_abstract(x_abstract);
  auto parameter_third = kernel_graph->NewParameter();
  MS_EXCEPTION_IF_NULL(parameter_third);
  parameter_third->set_abstract(x_abstract);
  // test cnode as input
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimAdd), parameter_one, parameter_two, parameter_third};
  auto add = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);
  KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC});
  builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
  d_kernel_info->set_select_kernel_build_info(builder.Build());
  EXPECT_EQ(AnfAlgo::GetInputDeviceShape(add, 0)[2], 224);
  EXPECT_EQ(AnfAlgo::GetInputDeviceShape(add, 1)[1], 32);
  ShapeVector expect_shape{2, 224, 224, 32};
  EXPECT_EQ(AnfAlgo::GetInputDeviceShape(add, 2), expect_shape);
  // Null input must be rejected by the API under test (GetInputDeviceShape), not a different helper.
  EXPECT_THROW(AnfAlgo::GetInputDeviceShape(nullptr, 0), std::runtime_error);
}
431
// GetOutputInferDataType returns the element type of each tuple output of the node's abstract.
// Index 5 is past the end of the 5-element tuple and is expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputInferDataTypeTest) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> bn_inputs{NewValueNode(prim::kPrimBatchNorm)};
  auto bn = graph->NewCNode(bn_inputs);
  MS_EXCEPTION_IF_NULL(bn);

  std::vector<int64_t> dims{2, 32, 224, 224};
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, dims);
  AbstractBasePtrList bn_outputs{tensor_abs, tensor_abs, tensor_abs, tensor_abs, tensor_abs};
  bn->set_abstract(std::make_shared<abstract::AbstractTuple>(bn_outputs));

  const auto float32_id = kFloat32->type_id();
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(bn, 0), float32_id);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(bn, 4), float32_id);
  EXPECT_THROW(common::AnfAlgo::GetOutputInferDataType(bn, 5), std::runtime_error);
}
446
// GetPrevNodeOutputInferDataType(node, i) returns the inferred dtype of the node feeding input i.
// An out-of-range index, a null node and a Parameter node are expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferDataType) {
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> producer_inputs{NewValueNode(prim::kPrimAdd)};
  auto producer = graph->NewCNode(producer_inputs);
  MS_EXCEPTION_IF_NULL(producer);
  std::vector<int64_t> dims{2, 32, 224, 224};
  producer->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, dims));

  std::vector<AnfNodePtr> consumer_inputs{NewValueNode(prim::kPrimAdd), producer};
  auto consumer = graph->NewCNode(consumer_inputs);
  EXPECT_EQ(common::AnfAlgo::GetPrevNodeOutputInferDataType(consumer, 0), kFloat32->type_id());
  EXPECT_THROW(common::AnfAlgo::GetPrevNodeOutputInferDataType(consumer, 1), std::runtime_error);
  EXPECT_THROW(common::AnfAlgo::GetPrevNodeOutputInferDataType(nullptr, 0), std::runtime_error);

  auto param = graph->add_parameter();
  EXPECT_THROW(common::AnfAlgo::GetPrevNodeOutputInferDataType(param, 0), std::runtime_error);
}
465
// GetOutputDeviceDataType returns the selected device dtype for an output index;
// an index beyond the selected outputs is expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(add_node);

  add_node->set_kernel_info(std::make_shared<KernelInfo>());
  auto *info = dynamic_cast<KernelInfo *>(add_node->kernel_info());
  MS_EXCEPTION_IF_NULL(info);
  KernelBuildInfoBuilder out_builder;
  out_builder.SetOutputsDeviceType({kFloat32->type_id()});
  out_builder.SetOutputsFormat({kOpFormat_NCHW});
  info->set_select_kernel_build_info(out_builder.Build());

  EXPECT_EQ(AnfAlgo::GetOutputDeviceDataType(add_node, 0), kFloat32->type_id());
  EXPECT_THROW(AnfAlgo::GetOutputDeviceDataType(add_node, 1), std::runtime_error);
}
482
// GetInputDeviceDataType returns the selected device dtype for each input index;
// an index beyond the selected inputs is expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs = {NewValueNode(prim::kPrimAdd), graph->NewParameter(),
                                         graph->NewParameter()};
  auto add_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(add_node);

  add_node->set_kernel_info(std::make_shared<KernelInfo>());
  auto *info = dynamic_cast<KernelInfo *>(add_node->kernel_info());
  MS_EXCEPTION_IF_NULL(info);
  KernelBuildInfoBuilder in_builder;
  in_builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()});
  in_builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NC1HWC0});
  info->set_select_kernel_build_info(in_builder.Build());

  EXPECT_EQ(AnfAlgo::GetInputDeviceDataType(add_node, 0), kFloat32->type_id());
  EXPECT_EQ(AnfAlgo::GetInputDeviceDataType(add_node, 1), kFloat16->type_id());
  EXPECT_THROW(AnfAlgo::GetInputDeviceDataType(add_node, 2), std::runtime_error);
}
500
// GetPrevNodeOutputDeviceDataType(node, i) returns the selected device dtype of the node feeding
// input i. An out-of-range index and a Parameter node are expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputDeviceDataType) {
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> producer_inputs{NewValueNode(prim::kPrimAdd)};
  auto producer = graph->NewCNode(producer_inputs);
  MS_EXCEPTION_IF_NULL(producer);
  producer->set_kernel_info(std::make_shared<KernelInfo>());
  auto *producer_info = dynamic_cast<KernelInfo *>(producer->kernel_info());
  MS_EXCEPTION_IF_NULL(producer_info);
  KernelBuildInfoBuilder producer_builder;
  producer_builder.SetOutputsDeviceType({kFloat32->type_id()});
  producer_info->set_select_kernel_build_info(producer_builder.Build());

  std::vector<AnfNodePtr> consumer_inputs{NewValueNode(prim::kPrimAdd), producer};
  auto consumer = graph->NewCNode(consumer_inputs);
  EXPECT_EQ(AnfAlgo::GetPrevNodeOutputDeviceDataType(consumer, 0), kFloat32->type_id());
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputDeviceDataType(consumer, 1), std::runtime_error);

  auto param = graph->add_parameter();
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputDeviceDataType(param, 0), std::runtime_error);
}
521
// GetOutputAddr returns the raw device-address pointer registered on the node's KernelInfo
// for the given output index.
TEST_F(AnfRuntimeAlgorithmTest, GetOutputAddr) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(add_node);
  add_node->set_kernel_info(std::make_shared<KernelInfo>());
  auto *info = dynamic_cast<KernelInfo *>(add_node->kernel_info());
  MS_EXCEPTION_IF_NULL(info);

  // The address wraps a null pointer with size 1; only the bookkeeping is under test.
  int *null_ptr = nullptr;
  auto address = std::make_shared<AscendDeviceAddress>(null_ptr, 1);
  info->SetOutputAddr(address, 0);
  EXPECT_EQ(AnfAlgo::GetOutputAddr(add_node, 0), address.get());
}
536
// GetPrevNodeOutputAddr(node, i) returns the output device address of the node feeding input i.
// An out-of-range index and a Parameter node are expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputAddr) {
  auto graph = std::make_shared<KernelGraph>();

  std::vector<AnfNodePtr> producer_inputs{NewValueNode(prim::kPrimAdd)};
  auto producer = graph->NewCNode(producer_inputs);
  MS_EXCEPTION_IF_NULL(producer);
  producer->set_kernel_info(std::make_shared<KernelInfo>());
  auto *producer_info = dynamic_cast<KernelInfo *>(producer->kernel_info());
  MS_EXCEPTION_IF_NULL(producer_info);
  int *null_ptr = nullptr;
  auto address = std::make_shared<AscendDeviceAddress>(null_ptr, 1);
  producer_info->SetOutputAddr(address, 0);

  std::vector<AnfNodePtr> consumer_inputs{NewValueNode(prim::kPrimAdd), producer};
  auto consumer = graph->NewCNode(consumer_inputs);
  EXPECT_EQ(AnfAlgo::GetPrevNodeOutputAddr(consumer, 0), address.get());
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputAddr(consumer, 1), std::runtime_error);

  auto param = graph->add_parameter();
  EXPECT_THROW(AnfAlgo::GetPrevNodeOutputAddr(param, 0), std::runtime_error);
}
557
// SetOutputAddr stores a device address for an output slot so GetOutputAddr can read it back;
// passing a null node is expected to throw.
TEST_F(AnfRuntimeAlgorithmTest, SetOutputAddr) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);

  int *null_ptr = nullptr;
  auto address = std::make_shared<AscendDeviceAddress>(null_ptr, 1);
  EXPECT_THROW(AnfAlgo::SetOutputAddr(address, 0, nullptr), std::runtime_error);

  AnfAlgo::SetOutputAddr(address, 0, add_node.get());
  EXPECT_EQ(AnfAlgo::GetOutputAddr(add_node, 0), address.get());
}
569
// GetWorkspaceAddr returns the raw device-address pointer registered on the node's KernelInfo
// for the given workspace index.
TEST_F(AnfRuntimeAlgorithmTest, GetWorkspaceAddr) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> node_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(add_node);
  add_node->set_kernel_info(std::make_shared<KernelInfo>());
  auto *info = dynamic_cast<KernelInfo *>(add_node->kernel_info());
  MS_EXCEPTION_IF_NULL(info);

  int *null_ptr = nullptr;
  auto address = std::make_shared<AscendDeviceAddress>(null_ptr, 1);
  info->SetWorkspaceAddr(address, 0);
  EXPECT_EQ(AnfAlgo::GetWorkspaceAddr(add_node, 0), address.get());
}
584
/// Feature: AnfAlgo::SetWorkspaceAddr.
/// Description: attach a workspace address to index 0 of an Add node; attaching to a null node must be rejected.
/// Expectation: GetWorkspaceAddr returns the same raw address that was attached.
TEST_F(AnfRuntimeAlgorithmTest, SetWorkspaceAddr) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);

  int *null_mem = nullptr;
  auto workspace_addr = std::make_shared<AscendDeviceAddress>(null_mem, 1);

  // A null target node throws.
  EXPECT_THROW(AnfAlgo::SetWorkspaceAddr(workspace_addr, 0, nullptr), std::runtime_error);
  // A real target node stores the address so the getter can observe it.
  AnfAlgo::SetWorkspaceAddr(workspace_addr, 0, add_node.get());
  EXPECT_EQ(AnfAlgo::GetWorkspaceAddr(add_node, 0), workspace_addr.get());
}
596
/// Feature: common::AnfAlgo::SetOutputInferTypeAndShape.
/// Description: set zero, one and three inferred outputs on an Add node. Also cover the error cases of a null node
///   and of type/shape lists with different lengths.
/// Expectation: the node's abstract reflects exactly the output types and shapes that were set.
TEST_F(AnfRuntimeAlgorithmTest, SetOutputInferTypeAndShape) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimAdd));
  auto add = kernel_graph->NewCNode(inputs);
  // No outputs: a null node is rejected; a real node ends up with an AbstractNone abstract.
  std::vector<TypeId> none_types = {};
  std::vector<ShapeVector> none_shapes = {};
  EXPECT_THROW(common::AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, nullptr), std::runtime_error);
  common::AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get());
  EXPECT_EQ((*add->abstract()), abstract::AbstractNone());
  // Single output. Passing 0 types together with 1 shape (mismatched lengths) must throw.
  std::vector<TypeId> single_types = {kFloat32->type_id()};
  std::vector<ShapeVector> single_shapes = {{2, 32, 224, 224}};
  EXPECT_THROW(common::AnfAlgo::SetOutputInferTypeAndShape(none_types, single_shapes, add.get()), std::runtime_error);
  common::AnfAlgo::SetOutputInferTypeAndShape(single_types, single_shapes, add.get());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(add, 0), kFloat32->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add, 0).size(), 4);
  // Multiple outputs: each output index reports its own type; all shapes are rank 4.
  std::vector<TypeId> multiple_types = {kFloat16->type_id(), kFloat32->type_id(), kFloat64->type_id()};
  std::vector<ShapeVector> multiple_shapes = {{2, 32, 224, 224}, {2, 32, 224, 224}, {2, 32, 224, 224}};
  common::AnfAlgo::SetOutputInferTypeAndShape(multiple_types, multiple_shapes, add.get());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(add, 0), kFloat16->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(add, 1), kFloat32->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(add, 2), kFloat64->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add, 0).size(), 4);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add, 1).size(), 4);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(add, 2).size(), 4);
}
626
/// Feature: common::AnfAlgo::CopyAbstract.
/// Description: copy the three-output abstract of one Add node onto another Add node that had a single output.
/// Expectation: the destination node reports the source node's three output types and shapes.
TEST_F(AnfRuntimeAlgorithmTest, CopyAbstract) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> first_inputs;
  first_inputs.push_back(NewValueNode(prim::kPrimAdd));
  auto first_add = kernel_graph->NewCNode(first_inputs);
  // Destination node: starts with a single float32 output.
  std::vector<TypeId> single_types = {kFloat32->type_id()};
  std::vector<ShapeVector> single_shapes = {{2, 32, 224, 224}};
  common::AnfAlgo::SetOutputInferTypeAndShape(single_types, single_shapes, first_add.get());
  // Source node: three outputs (float16, float32, float64).
  std::vector<AnfNodePtr> second_inputs;
  second_inputs.push_back(NewValueNode(prim::kPrimAdd));
  auto second_add = kernel_graph->NewCNode(second_inputs);
  std::vector<TypeId> multiple_types = {kFloat16->type_id(), kFloat32->type_id(), kFloat64->type_id()};
  std::vector<ShapeVector> multiple_shapes = {{2, 32, 224, 224}, {2, 32, 224, 224}, {2, 32, 224, 224}};
  common::AnfAlgo::SetOutputInferTypeAndShape(multiple_types, multiple_shapes, second_add.get());
  // Copy from second_add onto first_add; first_add must now mirror second_add's outputs.
  common::AnfAlgo::CopyAbstract(second_add, first_add.get());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(first_add, 0), kFloat16->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(first_add, 1), kFloat32->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferDataType(first_add, 2), kFloat64->type_id());
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(first_add, 0).size(), 4);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(first_add, 1).size(), 4);
  EXPECT_EQ(common::AnfAlgo::GetOutputInferShape(first_add, 2).size(), 4);
}
651
/// Feature: AnfAlgo::GetKernelType.
/// Description: select a kernel build info whose kernel type is AKG_KERNEL on an Add node.
/// Expectation: the getter reports AKG_KERNEL; querying a null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetKernelType) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto *d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);

  KernelBuildInfoBuilder info_builder;
  info_builder.SetKernelType(AKG_KERNEL);
  auto build_info = info_builder.Build();
  d_kernel_info->set_select_kernel_build_info(build_info);

  EXPECT_EQ(AnfAlgo::GetKernelType(add), AKG_KERNEL);
  EXPECT_THROW(AnfAlgo::GetKernelType(nullptr), std::runtime_error);
}
667
/// Feature: AnfAlgo::GetProcessor.
/// Description: select a kernel build info whose processor is kernel::AICORE on an Add node.
/// Expectation: the getter reports kernel::AICORE; querying a null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetProcessor) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto *d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);

  KernelBuildInfoBuilder info_builder;
  info_builder.SetProcessor(kernel::AICORE);
  auto build_info = info_builder.Build();
  d_kernel_info->set_select_kernel_build_info(build_info);

  EXPECT_EQ(AnfAlgo::GetProcessor(add), kernel::AICORE);
  EXPECT_THROW(AnfAlgo::GetProcessor(nullptr), std::runtime_error);
}
683
/// Feature: AnfAlgo::GetFusionType.
/// Description: select a kernel build info whose fusion type is kPatternConvolution on an Add node.
/// Expectation: the getter reports kPatternConvolution; querying a null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto *d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);

  KernelBuildInfoBuilder info_builder;
  info_builder.SetFusionType(kPatternConvolution);
  auto build_info = info_builder.Build();
  d_kernel_info->set_select_kernel_build_info(build_info);

  EXPECT_EQ(AnfAlgo::GetFusionType(add), kPatternConvolution);
  EXPECT_THROW(AnfAlgo::GetFusionType(nullptr), std::runtime_error);
}
699
/// Feature: AnfAlgo::SetSelectKernelBuildInfo.
/// Description: attach a build info carrying kPatternConvolution to an Add node; attaching to a null node throws.
/// Expectation: the fusion type read back from the node is kPatternConvolution.
TEST_F(AnfRuntimeAlgorithmTest, SetSelectKernelBuildInfo) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);

  auto info_builder = std::make_shared<KernelBuildInfoBuilder>();
  info_builder->SetFusionType(kPatternConvolution);

  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), add_node.get());
  EXPECT_THROW(AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), nullptr), std::runtime_error);
  EXPECT_EQ(AnfAlgo::GetFusionType(add_node), kPatternConvolution);
}
711
/// Feature: AnfAlgo::GetKernelMod.
/// Description: store a null kernel mod in an Add node's KernelInfo and read it back.
/// Expectation: the getter returns nullptr; querying a null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto *d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);

  d_kernel_info->set_kernel_mod(nullptr);

  EXPECT_EQ(AnfAlgo::GetKernelMod(add), nullptr);
  EXPECT_THROW(AnfAlgo::GetKernelMod(nullptr), std::runtime_error);
}
725
/// Feature: AnfAlgo::SetKernelMod.
/// Description: set a null kernel mod on an Add node; setting it on a null node throws.
/// Expectation: the kernel mod read back from the node is nullptr.
TEST_F(AnfRuntimeAlgorithmTest, SetKernelMod) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);

  AnfAlgo::SetKernelMod(nullptr, add_node.get());
  EXPECT_THROW(AnfAlgo::SetKernelMod(nullptr, nullptr), std::runtime_error);
  EXPECT_EQ(AnfAlgo::GetKernelMod(add_node), nullptr);
}
735
/// Feature: AnfUtils::IsRealKernel.
/// Description: classify a value node, a graph parameter, an Add cnode and a Depend cnode; a null node throws.
/// Expectation: only the Depend cnode is reported as not being a real kernel.
TEST_F(AnfRuntimeAlgorithmTest, IsRealKernel) {
  auto graph = std::make_shared<KernelGraph>();

  // Value node.
  auto prim_value = NewValueNode(prim::kPrimAdd);
  EXPECT_TRUE(AnfUtils::IsRealKernel(prim_value));
  EXPECT_THROW(AnfUtils::IsRealKernel(nullptr), std::runtime_error);

  // Graph parameter.
  auto param = graph->add_parameter();
  EXPECT_TRUE(AnfUtils::IsRealKernel(param));

  // Add cnode.
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);
  EXPECT_TRUE(AnfUtils::IsRealKernel(add_node));

  // Depend cnode.
  std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend)};
  auto depend_node = graph->NewCNode(depend_inputs);
  EXPECT_FALSE(AnfUtils::IsRealKernel(depend_node));
}
756
/// Feature: AnfUtils::IsRealCNodeKernel.
/// Description: classify a value node, a graph parameter, an Add cnode and a Depend cnode; a null node throws.
/// Expectation: only the Add cnode is reported as a real cnode kernel.
TEST_F(AnfRuntimeAlgorithmTest, IsRealCNodeKernel) {
  auto kernel_graph = std::make_shared<KernelGraph>();
  // test value node as input
  auto value_node = NewValueNode(prim::kPrimAdd);
  EXPECT_FALSE(AnfUtils::IsRealCNodeKernel(value_node));
  EXPECT_THROW(AnfUtils::IsRealCNodeKernel(nullptr), std::runtime_error);
  // test parameter as input
  auto parameter_node = kernel_graph->add_parameter();
  EXPECT_FALSE(AnfUtils::IsRealCNodeKernel(parameter_node));
  // test add as input
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimAdd));
  auto add = kernel_graph->NewCNode(inputs);
  EXPECT_TRUE(AnfUtils::IsRealCNodeKernel(add));
  // test Depend as input: a Depend cnode is not a real cnode kernel
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimDepend));
  auto depend = kernel_graph->NewCNode(inputs);
  EXPECT_FALSE(AnfUtils::IsRealCNodeKernel(depend));
}
777
/// Feature: common::AnfAlgo::IsParameterWeight.
/// Description: give a graph parameter a default tensor value (int64 0, kInt32); querying a null node throws.
/// Expectation: the parameter is reported as a weight.
TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {
  auto graph = std::make_shared<KernelGraph>();
  auto parameter_node = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(parameter_node);

  auto default_value = std::make_shared<tensor::Tensor>(static_cast<int64_t>(0), kInt32);
  parameter_node->set_default_param(default_value);

  EXPECT_TRUE(common::AnfAlgo::IsParameterWeight(parameter_node));
  EXPECT_THROW(common::AnfAlgo::IsParameterWeight(nullptr), std::runtime_error);
}
787
/// Feature: AnfAlgo::GetStreamId.
/// Description: set stream id 0 directly in an Add node's KernelInfo and read it back through AnfAlgo.
/// Expectation: the getter returns 0; querying a null node throws.
TEST_F(AnfRuntimeAlgorithmTest, GetStreamId) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add = graph->NewCNode(add_inputs);
  MS_EXCEPTION_IF_NULL(add);
  add->set_kernel_info(std::make_shared<KernelInfo>());
  auto *d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
  MS_EXCEPTION_IF_NULL(d_kernel_info);

  d_kernel_info->set_stream_id(0);

  EXPECT_EQ(AnfAlgo::GetStreamId(add), 0);
  EXPECT_THROW(AnfAlgo::GetStreamId(nullptr), std::runtime_error);
}
801
/// Feature: AnfAlgo::SetStreamId.
/// Description: set stream id 0 on an Add node through AnfAlgo; setting it on a null node throws.
/// Expectation: the stream id read back from the node is 0.
TEST_F(AnfRuntimeAlgorithmTest, SetStreamId) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd)};
  auto add_node = graph->NewCNode(add_inputs);

  AnfAlgo::SetStreamId(0, add_node.get());
  EXPECT_THROW(AnfAlgo::SetStreamId(0, nullptr), std::runtime_error);
  EXPECT_EQ(AnfAlgo::GetStreamId(add_node), 0);
}
811
812 /// Feature: fix get item to tuple in tuple.
813 /// Description: check the function between tuple and getitem.
814 /// Expectation: As expected.
TEST_F(AnfRuntimeAlgorithmTest,GetAllOutputWithIndex)815 TEST_F(AnfRuntimeAlgorithmTest, GetAllOutputWithIndex) {
816 std::vector<int64_t> shp{1};
817 auto graph = std::make_shared<FuncGraph>();
818 std::vector<KernelWithIndex> results;
819
820 auto value_node_0 = NewValueNode(static_cast<int64_t>(0));
821 auto value_node_1 = NewValueNode(static_cast<int64_t>(1));
822 auto value_node_2 = NewValueNode(static_cast<int64_t>(2));
823 auto value_node_3 = NewValueNode(static_cast<int64_t>(3));
824 auto abstract_single = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
825 AbstractBasePtrList abstruct_list_1{abstract_single, abstract_single};
826 auto abstruct_tuple = std::make_shared<abstract::AbstractTuple>(abstruct_list_1);
827 AbstractBasePtrList abstruct_list_2{abstract_single, abstruct_tuple, abstract_single};
828 auto abstract_tuple_in_tuple = std::make_shared<abstract::AbstractTuple>(abstruct_list_2);
829 value_node_1->set_abstract(abstract_single);
830
831 // Make tuple node: (1, 2)
832 std::vector<AnfNodePtr> tuple_input1{NewValueNode(prim::kPrimMakeTuple), value_node_1, value_node_2};
833 auto tuple_1 = graph->NewCNode(tuple_input1);
834 tuple_1->set_abstract(abstruct_tuple);
835 results = common::AnfAlgo::GetAllOutputWithIndex(tuple_1);
836 EXPECT_EQ(results.size(), 2);
837
838 // Make tuple node: (0, (1, 2), 3)
839 std::vector<AnfNodePtr> tuple_input2{NewValueNode(prim::kPrimMakeTuple), value_node_0, tuple_1, value_node_3};
840 auto tuple_2 = graph->NewCNode(tuple_input2);
841 tuple_2->set_abstract(abstract_tuple_in_tuple);
842 results = common::AnfAlgo::GetAllOutputWithIndex(tuple_2);
843 EXPECT_EQ(results.size(), 4);
844
845 // Get item node: (0, (1, 2), 3)[1] = (1, 2)
846 std::vector<AnfNodePtr> getitem_input1{NewValueNode(prim::kPrimTupleGetItem), tuple_2, value_node_1};
847 auto getitem_1 = graph->NewCNode(getitem_input1);
848 getitem_1->set_abstract(abstruct_tuple);
849 results = common::AnfAlgo::GetAllOutputWithIndex(getitem_1);
850 EXPECT_EQ(results.size(), 2);
851 EXPECT_EQ(GetValue<int64_t>(results[0].first->cast<ValueNodePtr>()->value()), 1);
852 EXPECT_EQ(GetValue<int64_t>(results[1].first->cast<ValueNodePtr>()->value()), 2);
853
854 // Get item node: (0, (1, 2), 3)[1][1] = (1, 2)[1] = (1)
855 std::vector<AnfNodePtr> getitem_input2{NewValueNode(prim::kPrimTupleGetItem), getitem_1, value_node_1};
856 auto getitem_2 = graph->NewCNode(getitem_input2);
857 getitem_2->set_abstract(abstract_single);
858 results = common::AnfAlgo::GetAllOutputWithIndex(getitem_2);
859 EXPECT_EQ(results.size(), 1);
860
861 // Call node with output construct: (x, (x, x), x)
862 std::vector<AnfNodePtr> call_input{NewValueNode(graph)};
863 auto call = graph->NewCNode(call_input);
864 call->set_abstract(abstract_tuple_in_tuple);
865 results = common::AnfAlgo::GetAllOutputWithIndex(call);
866 EXPECT_EQ(results.size(), 4);
867
868 // Get item node: (x, (x, x), x)[1] = (x, x)
869 std::vector<AnfNodePtr> getitem_input3{NewValueNode(prim::kPrimTupleGetItem), call, value_node_1};
870 auto getitem_3 = graph->NewCNode(getitem_input3);
871 getitem_3->set_abstract(abstruct_tuple);
872 results = common::AnfAlgo::GetAllOutputWithIndex(getitem_3);
873 EXPECT_EQ(results.size(), 2);
874 EXPECT_EQ(results[0].second, 1);
875 EXPECT_EQ(results[1].second, 2);
876
877 // Get item node: (0, (1, 2), 3)[2] = (3)
878 std::vector<AnfNodePtr> getitem_input4{NewValueNode(prim::kPrimTupleGetItem), tuple_2, value_node_2};
879 auto getitem_4 = graph->NewCNode(getitem_input4);
880 getitem_4->set_abstract(abstract_single);
881 results = common::AnfAlgo::GetAllOutputWithIndex(getitem_4);
882 EXPECT_EQ(results.size(), 1);
883 EXPECT_EQ(GetValue<int64_t>(results[0].first->cast<ValueNodePtr>()->value()), 3);
884
885 // Get item node: (x, (x, x), x)[2] = (x)
886 std::vector<AnfNodePtr> getitem_input5{NewValueNode(prim::kPrimTupleGetItem), call, value_node_2};
887 auto getitem_5 = graph->NewCNode(getitem_input5);
888 getitem_5->set_abstract(abstract_single);
889 results = common::AnfAlgo::GetAllOutputWithIndex(getitem_5);
890 EXPECT_EQ(results.size(), 1);
891 EXPECT_EQ(results[0].second, 3);
892
893 // Make tuple node: (3, (x, (x, x), x)[2], 1) = (3, x, 1)
894 std::vector<AnfNodePtr> tuple_input3{NewValueNode(prim::kPrimMakeTuple), value_node_3, getitem_5, value_node_1};
895 auto tuple_3 = graph->NewCNode(tuple_input3);
896 tuple_3->set_abstract(abstract_tuple_in_tuple);
897 results = common::AnfAlgo::GetAllOutputWithIndex(tuple_3);
898 EXPECT_EQ(results.size(), 3);
899
900 // Get item node: (3, (x, (x, x), x)[2], 1)[1] = (3, x, 1)[1] = (x)
901 std::vector<AnfNodePtr> getitem_input6{NewValueNode(prim::kPrimTupleGetItem), tuple_3, value_node_1};
902 auto getitem_6 = graph->NewCNode(getitem_input5);
903 getitem_6->set_abstract(abstract_single);
904 results = common::AnfAlgo::GetAllOutputWithIndex(getitem_6);
905 EXPECT_EQ(results.size(), 1);
906 EXPECT_EQ(results[0].second, 3);
907 }
908 } // namespace session
909 } // namespace mindspore
910