• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/common_test.h"
18 #include "frontend/operator/ops.h"
19 #include "backend/session/ascend_session.h"
20 #include "backend/session/kernel_graph.h"
21 #include "backend/session/anf_runtime_algorithm.h"
22 #include "utils/utils.h"
23 
24 namespace mindspore {
25 namespace session {
26 
27 class SessionBasicTest : public UT::Common {
28  public:
29   SessionBasicTest() = default;
30   void SetUp() override {}
31   void TearDown() override {}
32 };
33 
34 TEST_F(SessionBasicTest, ConstructKernelGraph) {
35   /*
36    * define kernel graph:
37    *     x ----- y
38    *         add ----- z
39    *               mul
40    *              return
41    */
42   auto anf_graph = std::make_shared<FuncGraph>();
43   std::vector<int64_t> shape = {2, 32, 224, 224};
44   auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
45   EXPECT_NE(abstract, nullptr);
46 
47   auto original_x_parameter = anf_graph->add_parameter();
48   EXPECT_NE(original_x_parameter, nullptr);
49   original_x_parameter->set_name("original_x_parameter");
50   original_x_parameter->set_abstract(abstract);
51   auto original_y_parameter = anf_graph->add_parameter();
52   EXPECT_NE(original_y_parameter, nullptr);
53   original_y_parameter->set_name("original_y_parameter");
54   original_y_parameter->set_abstract(abstract);
55   std::vector<AnfNodePtr> add_inputs = {NewValueNode(prim::kPrimAdd), original_x_parameter, original_y_parameter};
56   auto original_add = anf_graph->NewCNode(add_inputs);
57   EXPECT_NE(original_add, nullptr);
58   original_add->set_abstract(abstract);
59 
60   auto original_z_parameter = anf_graph->add_parameter();
61   EXPECT_NE(original_z_parameter, nullptr);
62   original_z_parameter->set_name("original_z_parameter");
63   original_z_parameter->set_abstract(abstract);
64   std::vector<AnfNodePtr> mul_inputs = {NewValueNode(prim::kPrimMul), original_add, original_z_parameter};
65   auto original_mul = anf_graph->NewCNode(mul_inputs);
66   EXPECT_NE(original_mul, nullptr);
67   original_mul->set_abstract(abstract);
68 
69   std::vector<AnfNodePtr> lst = {original_add, original_mul};
70   std::vector<AnfNodePtr> outputs = {original_mul};
71   session::SessionPtr sess = std::make_shared<session::AscendSession>();
72   sess->Init(0);
73   auto kernel_graph = sess->ConstructKernelGraph(lst, outputs);
74   EXPECT_NE(kernel_graph, nullptr);
75 
76   auto inputs = kernel_graph->inputs();
77   EXPECT_EQ(inputs.size(), 3);
78   auto first_input = inputs[0]->cast<ParameterPtr>();
79   EXPECT_NE(first_input, nullptr);
80   EXPECT_EQ(first_input->name(), "original_x_parameter");
81   auto second_input = inputs[1]->cast<ParameterPtr>();
82   EXPECT_NE(second_input, nullptr);
83   EXPECT_EQ(second_input->name(), "original_y_parameter");
84   auto third_input = inputs[2]->cast<ParameterPtr>();
85   EXPECT_NE(third_input, nullptr);
86   EXPECT_EQ(third_input->name(), "original_z_parameter");
87   kernel_graph->SetExecOrderByDefault();
88   auto execution_order = kernel_graph->execution_order();
89   EXPECT_EQ(execution_order.size(), 2);
90   EXPECT_EQ(AnfAlgo::GetCNodeName(execution_order[0]), prim::kPrimAdd->name());
91   EXPECT_EQ(AnfAlgo::GetCNodeName(execution_order[1]), prim::kPrimMul->name());
92   auto new_outputs = kernel_graph->outputs();
93   EXPECT_EQ(new_outputs.size(), 1);
94   EXPECT_EQ(AnfAlgo::GetCNodeName(new_outputs[0]), prim::kPrimMul->name());
95 };
96 
97 }  // namespace session
98 }  // namespace mindspore