• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/graph_optimizer_test_framework.h"
17 #include <memory>
18 #include <string>
19 #include <vector>
20 #include <utility>
21 #include <deque>
22 #include <algorithm>
23 #include <set>
24 #include "common/common_test.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 
29 namespace mindspore::test {
RunPass(const FuncGraphPtr & graph,const std::vector<opt::PassPtr> & passes)30 void RunPass(const FuncGraphPtr &graph, const std::vector<opt::PassPtr> &passes) {
31   UT_CHECK_NULL(graph);
32   auto optimizer = std::make_shared<opt::GraphOptimizer>();
33   auto pm = std::make_shared<opt::PassManager>();
34   for (const auto &pass : passes) {
35     UT_CHECK_NULL(pass);
36     pm->AddPass(pass);
37   }
38   optimizer->AddPassManager(pm);
39   (void)optimizer->Optimize(graph);
40 }
41 
ConstructGraph()42 ConstructGraph::ConstructGraph() : graph_(std::make_shared<session::KernelGraph>()) { graph_->set_graph_id(0); }
43 
GetGraph() const44 const std::shared_ptr<session::KernelGraph> &ConstructGraph::GetGraph() const { return graph_; }
45 
NewInput(const std::string & name,const AbstractBasePtr & abs)46 ParameterPtr ConstructGraph::NewInput(const std::string &name, const AbstractBasePtr &abs) {
47   MS_EXCEPTION_IF_NULL(graph_);
48   auto new_param = std::make_shared<Parameter>(graph_);
49   new_param->set_name(name);
50   new_param->set_abstract(abs);
51   new_param = graph_->NewParameter(new_param);
52   return new_param;
53 }
54 
NewScalarInput(const std::string & name,const TypePtr & type)55 ParameterPtr ConstructGraph::NewScalarInput(const std::string &name, const TypePtr &type) {
56   auto abs = std::make_shared<abstract::AbstractScalar>(type);
57   return NewInput(name, abs);
58 }
59 
NewTensorInput(const std::string & name,const TypePtr & type,const ShapeVector & shape)60 ParameterPtr ConstructGraph::NewTensorInput(const std::string &name, const TypePtr &type, const ShapeVector &shape) {
61   auto abs = std::make_shared<abstract::AbstractTensor>(type, shape);
62   return NewInput(name, abs);
63 }
64 
NewTupleInput(const std::string & name,const std::vector<std::pair<TypePtr,ShapeVector>> & pairs)65 ParameterPtr ConstructGraph::NewTupleInput(const std::string &name,
66                                            const std::vector<std::pair<TypePtr, ShapeVector>> &pairs) {
67   AbstractBasePtrList list;
68   for (const auto &[type, shape] : pairs) {
69     auto abs = std::make_shared<abstract::AbstractTensor>(type, shape);
70     list.emplace_back(std::move(abs));
71   }
72   auto abs = std::make_shared<abstract::AbstractTuple>(std::move(list), nullptr);
73   return NewInput(name, abs);
74 }
75 
NewListInput(const std::string & name,const std::vector<std::pair<TypePtr,ShapeVector>> & pairs)76 ParameterPtr ConstructGraph::NewListInput(const std::string &name,
77                                           const std::vector<std::pair<TypePtr, ShapeVector>> &pairs) {
78   AbstractBasePtrList list;
79   for (const auto &[type, shape] : pairs) {
80     auto abs = std::make_shared<abstract::AbstractTensor>(type, shape);
81     list.emplace_back(std::move(abs));
82   }
83   auto abs = std::make_shared<abstract::AbstractList>(std::move(list), nullptr);
84   return NewInput(name, abs);
85 }
86 
NewValueNode(const ValuePtr & value)87 ValueNodePtr ConstructGraph::NewValueNode(const ValuePtr &value) { return graph_->NewValueNode(value); }
88 
NewCNodeWithoutInfer(const std::string & prim_name,const std::vector<AnfNodePtr> & inputs,const mindspore::HashMap<std::string,ValuePtr> & attrs)89 CNodePtr ConstructGraph::NewCNodeWithoutInfer(const std::string &prim_name, const std::vector<AnfNodePtr> &inputs,
90                                               const mindspore::HashMap<std::string, ValuePtr> &attrs) {
91   MS_EXCEPTION_IF_NULL(graph_);
92   auto prim = std::make_shared<Primitive>(prim_name);
93   prim->SetAttrs(attrs);
94   auto value_node = std::make_shared<ValueNode>(prim);
95   std::vector<AnfNodePtr> new_inputs = {value_node};
96   new_inputs.insert(new_inputs.end(), inputs.begin(), inputs.end());
97   auto cnode = graph_->NewCNode(new_inputs);
98   return cnode;
99 }
100 
NewCNode(const std::string & prim_name,const std::vector<AnfNodePtr> & inputs,const mindspore::HashMap<std::string,ValuePtr> & attrs)101 CNodePtr ConstructGraph::NewCNode(const std::string &prim_name, const std::vector<AnfNodePtr> &inputs,
102                                   const mindspore::HashMap<std::string, ValuePtr> &attrs) {
103   auto cnode = NewCNodeWithoutInfer(prim_name, inputs, attrs);
104   AbstractBasePtrList args;
105   std::transform(inputs.begin(), inputs.end(), std::back_inserter(args),
106                  [](const AnfNodePtr &node) -> abstract::AbstractBasePtr { return node->abstract(); });
107   auto out_abs = opt::CppInferShapeAndType(GetCNodePrimitive(cnode), args);
108   cnode->set_abstract(out_abs);
109   return cnode;
110 }
111 
SetGeneralBuildInfo(const AnfNodePtr & node)112 void ConstructGraph::SetGeneralBuildInfo(const AnfNodePtr &node) {
113   kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
114   auto cnode = node->cast<CNodePtr>();
115   MS_EXCEPTION_IF_NULL(cnode);
116   size_t input_num = cnode->size() - 1;
117   info_builder.SetInputsFormat(std::vector<std::string>(input_num, "DefaultFormat"));
118   std::vector<TypeId> input_types(input_num);
119   for (size_t i = 0; i < input_types.size(); i++) {
120     input_types[i] = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, i);
121   }
122   info_builder.SetInputsDeviceType(input_types);
123   info_builder.SetInputsKernelObjectType(
124     std::vector<kernel::KernelObjectType>(input_num, kernel::KernelObjectType::TENSOR));
125   size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(node->abstract());
126   info_builder.SetOutputsFormat(std::vector<std::string>(output_num, "DefaultFormat"));
127   std::vector<TypeId> output_types(output_num);
128   for (size_t i = 0; i < output_types.size(); i++) {
129     input_types[i] = common::AnfAlgo::GetOutputInferDataType(node, i);
130   }
131   info_builder.SetOutputsDeviceType(output_types);
132   info_builder.SetOutputsKernelObjectType(
133     std::vector<kernel::KernelObjectType>(output_num, kernel::KernelObjectType::TENSOR));
134   AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), node.get());
135 }
136 
SetOutput(const AnfNodePtr & node)137 void ConstructGraph::SetOutput(const AnfNodePtr &node) {
138   MS_EXCEPTION_IF_NULL(graph_);
139   graph_->set_output(node, true);
140 }
141 }  // namespace mindspore::test
142