1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_compiler/func_wrapper.h"
17 #include <algorithm>
18 #include <iterator>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h"
24 #include "pipeline/jit/pi/graph_compiler/pi_ir/value.h"
25 #include "utils/log_adapter.h"
26
27 namespace mindspore {
28 namespace pijit {
GetInputs()29 const ValuePtrList &InputsCollector::GetInputs() {
30 VISIT_NODE_LIST(nodes_)
31 return inputs_;
32 }
33
Visit_(const ir::LoadValueNodePtr & node)34 void InputsCollector::Visit_(const ir::LoadValueNodePtr &node) {
35 if (node->GetOpCode() == LOAD_FAST) {
36 AddInput(node->GetArg(0));
37 } else {
38 VISIT_NODE_LIST(node->GetArgs())
39 }
40 }
41
Visit_(const ir::StoreNodePtr & node)42 void InputsCollector::Visit_(const ir::StoreNodePtr &node) {
43 Visit(node->GetArg(0));
44 if (node->GetOpCode() == STORE_FAST) {
45 AddAssignedVar(node->GetArg(1));
46 }
47 }
48
AddInput(const ir::NodePtr & input)49 void InputsCollector::AddInput(const ir::NodePtr &input) {
50 MS_EXCEPTION_IF_CHECK_FAIL(input->isa<ir::Value>(), input->ToString() + " is not excepted.");
51 auto value = input->cast<ir::ValuePtr>();
52 if ((assigned_vars_.find(value->GetValue()) == assigned_vars_.end()) &&
53 (input_names_.find(value->GetValue()) == input_names_.end())) {
54 input_names_.insert(value->GetValue());
55 inputs_.push_back(value);
56 }
57 }
58
AddAssignedVar(const ir::NodePtr & var)59 void InputsCollector::AddAssignedVar(const ir::NodePtr &var) {
60 MS_EXCEPTION_IF_CHECK_FAIL(var->isa<ir::Value>(), var->ToString() + " is not excepted.");
61 auto value = var->cast<ir::ValuePtr>();
62 assigned_vars_.insert(value->GetValue());
63 }
64
GetOutputs()65 const ValuePtrList &OutputsCollector::GetOutputs() {
66 VISIT_NODE_LIST(nodes_)
67 return outputs_;
68 }
69
Visit_(const ir::StoreNodePtr & node)70 void OutputsCollector::Visit_(const ir::StoreNodePtr &node) {
71 if (node->GetOpCode() == STORE_FAST) {
72 AddOutput(node->GetRightArg());
73 } else {
74 Visit(node->GetLeftArg());
75 Visit(node->GetRightArg());
76 }
77 }
78
AddOutput(const ir::NodePtr & output)79 void OutputsCollector::AddOutput(const ir::NodePtr &output) {
80 MS_EXCEPTION_IF_CHECK_FAIL(output->isa<ir::Value>(), output->ToString() + " is not excepted.");
81 auto value = output->cast<ir::ValuePtr>();
82 if (output_names_.find(value->GetValue()) == output_names_.end()) {
83 output_names_.insert(value->GetValue());
84 outputs_.push_back(value);
85 }
86 }
87
Wrapper()88 ir::FunctionNodePtr FuncWrapper::Wrapper() {
89 (void)GetOutputs();
90 GenerateReturn();
91 GenerateParameters();
92 return func_;
93 }
94
GetOutputs()95 const ValuePtrList &FuncWrapper::GetOutputs() {
96 if (outputs_.empty()) {
97 auto output_collector = std::make_shared<OutputsCollector>(func_->GetNodes());
98 outputs_ = output_collector->GetOutputs();
99 }
100 return outputs_;
101 }
102
GenerateParameters() const103 void FuncWrapper::GenerateParameters() const {
104 auto intput_collector = std::make_shared<InputsCollector>(func_->GetNodes());
105 auto inputs = intput_collector->GetInputs();
106 size_t index = 0;
107 std::for_each(inputs.begin(), inputs.end(), [&index, this](const ir::ValuePtr &input) {
108 std::string name = py::cast<std::string>(input->GetValue());
109 ir::ParameterPtr param = std::make_shared<ir::Parameter>(index, name);
110 // Set arg as positional parameter
111 param->SetCategory(ir::Parameter::POSITIONAL);
112 func_->AddParameter(param);
113 index++;
114 });
115 func_->SetPosArgsCnt(index);
116 }
117
GenerateReturn() const118 void FuncWrapper::GenerateReturn() const {
119 auto nodes = func_->GetNodes();
120 if (!nodes.empty() && nodes.back()->isa<ir::ReturnNode>()) {
121 return;
122 }
123 MS_EXCEPTION_IF_CHECK_FAIL(!outputs_.empty(), "Output can not be empty.");
124 ir::NodePtrList opnds;
125 std::transform(outputs_.begin(), outputs_.end(), std::back_inserter(opnds),
126 [](const ir::ValuePtr &value) { return std::make_shared<ir::LoadValueNode>(LOAD_FAST, value); });
127 ir::NodePtr tuple = std::make_shared<ir::BuildNode>(BUILD_TUPLE, opnds);
128 ir::ReturnNodePtr ret = std::make_shared<ir::ReturnNode>(tuple);
129 func_->AddNode(ret);
130 }
131
132 } // namespace pijit
133 } // namespace mindspore
134