• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_compiler/func_wrapper.h"
17 #include <algorithm>
18 #include <iterator>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23 #include "pipeline/jit/pi/graph_compiler/pi_ir/custom_nodes.h"
24 #include "pipeline/jit/pi/graph_compiler/pi_ir/value.h"
25 #include "utils/log_adapter.h"
26 
27 namespace mindspore {
28 namespace pijit {
/**
 * Collect the free variables (function inputs) of the wrapped node list.
 *
 * Walks every node in nodes_. As a side effect, the Visit_ overloads record each local read via
 * LOAD_FAST that was not already assigned by a STORE_FAST seen earlier in the walk. Names are
 * deduplicated and kept in the order they are first encountered.
 *
 * There is no caching: every call re-walks nodes_. Repeated calls do not produce duplicates
 * because input_names_ persists across calls.
 *
 * @return The collected inputs (owned by this collector).
 */
const ValuePtrList &InputsCollector::GetInputs() {
  VISIT_NODE_LIST(nodes_)
  return inputs_;
}
33 
Visit_(const ir::LoadValueNodePtr & node)34 void InputsCollector::Visit_(const ir::LoadValueNodePtr &node) {
35   if (node->GetOpCode() == LOAD_FAST) {
36     AddInput(node->GetArg(0));
37   } else {
38     VISIT_NODE_LIST(node->GetArgs())
39   }
40 }
41 
Visit_(const ir::StoreNodePtr & node)42 void InputsCollector::Visit_(const ir::StoreNodePtr &node) {
43   Visit(node->GetArg(0));
44   if (node->GetOpCode() == STORE_FAST) {
45     AddAssignedVar(node->GetArg(1));
46   }
47 }
48 
AddInput(const ir::NodePtr & input)49 void InputsCollector::AddInput(const ir::NodePtr &input) {
50   MS_EXCEPTION_IF_CHECK_FAIL(input->isa<ir::Value>(), input->ToString() + " is not excepted.");
51   auto value = input->cast<ir::ValuePtr>();
52   if ((assigned_vars_.find(value->GetValue()) == assigned_vars_.end()) &&
53       (input_names_.find(value->GetValue()) == input_names_.end())) {
54     input_names_.insert(value->GetValue());
55     inputs_.push_back(value);
56   }
57 }
58 
AddAssignedVar(const ir::NodePtr & var)59 void InputsCollector::AddAssignedVar(const ir::NodePtr &var) {
60   MS_EXCEPTION_IF_CHECK_FAIL(var->isa<ir::Value>(), var->ToString() + " is not excepted.");
61   auto value = var->cast<ir::ValuePtr>();
62   assigned_vars_.insert(value->GetValue());
63 }
64 
/**
 * Collect the outputs of the wrapped node list.
 *
 * Walks every node in nodes_. As a side effect, the Visit_ overloads record the target of each
 * STORE_FAST. Names are deduplicated and kept in the order they are first encountered.
 *
 * There is no caching: every call re-walks nodes_. Repeated calls do not produce duplicates
 * because output_names_ persists across calls.
 *
 * @return The collected outputs (owned by this collector).
 */
const ValuePtrList &OutputsCollector::GetOutputs() {
  VISIT_NODE_LIST(nodes_)
  return outputs_;
}
69 
Visit_(const ir::StoreNodePtr & node)70 void OutputsCollector::Visit_(const ir::StoreNodePtr &node) {
71   if (node->GetOpCode() == STORE_FAST) {
72     AddOutput(node->GetRightArg());
73   } else {
74     Visit(node->GetLeftArg());
75     Visit(node->GetRightArg());
76   }
77 }
78 
AddOutput(const ir::NodePtr & output)79 void OutputsCollector::AddOutput(const ir::NodePtr &output) {
80   MS_EXCEPTION_IF_CHECK_FAIL(output->isa<ir::Value>(), output->ToString() + " is not excepted.");
81   auto value = output->cast<ir::ValuePtr>();
82   if (output_names_.find(value->GetValue()) == output_names_.end()) {
83     output_names_.insert(value->GetValue());
84     outputs_.push_back(value);
85   }
86 }
87 
/**
 * Turn func_ into a self-contained function and return it.
 *
 * The call order matters:
 * 1. GetOutputs() fills outputs_, which GenerateReturn() reads but does not compute itself.
 * 2. GenerateReturn() appends `return (outputs...)` unless the body already ends with a return.
 * 3. GenerateParameters() runs last, so its input scan sees the final node list, including any
 *    appended return.
 *
 * @return func_, modified in place.
 */
ir::FunctionNodePtr FuncWrapper::Wrapper() {
  (void)GetOutputs();
  GenerateReturn();
  GenerateParameters();
  return func_;
}
94 
GetOutputs()95 const ValuePtrList &FuncWrapper::GetOutputs() {
96   if (outputs_.empty()) {
97     auto output_collector = std::make_shared<OutputsCollector>(func_->GetNodes());
98     outputs_ = output_collector->GetOutputs();
99   }
100   return outputs_;
101 }
102 
GenerateParameters() const103 void FuncWrapper::GenerateParameters() const {
104   auto intput_collector = std::make_shared<InputsCollector>(func_->GetNodes());
105   auto inputs = intput_collector->GetInputs();
106   size_t index = 0;
107   std::for_each(inputs.begin(), inputs.end(), [&index, this](const ir::ValuePtr &input) {
108     std::string name = py::cast<std::string>(input->GetValue());
109     ir::ParameterPtr param = std::make_shared<ir::Parameter>(index, name);
110     // Set arg as positional parameter
111     param->SetCategory(ir::Parameter::POSITIONAL);
112     func_->AddParameter(param);
113     index++;
114   });
115   func_->SetPosArgsCnt(index);
116 }
117 
GenerateReturn() const118 void FuncWrapper::GenerateReturn() const {
119   auto nodes = func_->GetNodes();
120   if (!nodes.empty() && nodes.back()->isa<ir::ReturnNode>()) {
121     return;
122   }
123   MS_EXCEPTION_IF_CHECK_FAIL(!outputs_.empty(), "Output can not be empty.");
124   ir::NodePtrList opnds;
125   std::transform(outputs_.begin(), outputs_.end(), std::back_inserter(opnds),
126                  [](const ir::ValuePtr &value) { return std::make_shared<ir::LoadValueNode>(LOAD_FAST, value); });
127   ir::NodePtr tuple = std::make_shared<ir::BuildNode>(BUILD_TUPLE, opnds);
128   ir::ReturnNodePtr ret = std::make_shared<ir::ReturnNode>(tuple);
129   func_->AddNode(ret);
130 }
131 
132 }  // namespace pijit
133 }  // namespace mindspore
134