/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "torch/csrc/jit/passes/freeze_module.h"
#include "torch/csrc/jit/passes/inliner.h"
#include "torch/csrc/jit/passes/remove_mutation.h"
#include "torch/csrc/jit/passes/normalize_ops.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.h"

namespace torch {
namespace jit {
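/*
Flatten the graph's return values so that every graph output is a plain value:
a prim::TupleConstruct / prim::DictConstruct / prim::ListConstruct feeding the
return is replaced by its elements, like below:
  %out : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
  return (%out)
becomes
  return (%a, %b)
Nested TupleConstructs are flattened recursively.
*/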
void OutputsUnpack(Graph *graph) {
  // depth-first walk that records every nested TupleConstruct and collects the
  // leaf (non-tuple) values in order
  std::function<void(Node * tuple, std::vector<Node *> &, std::vector<Value *> &)> flattenTuple =
    [&flattenTuple](Node *tuple, std::vector<Node *> &tuples, std::vector<Value *> &values) -> void {
    tuples.push_back(tuple);
    for (auto input : tuple->inputs()) {
      auto node = input->node();
      if (node->kind() == prim::TupleConstruct) {
        flattenTuple(node, tuples, values);
      } else {
        values.push_back(input);
      }
    }
  };
  // no increment here: when an output is erased, the value that shifts into
  // slot i must be examined as well
  for (size_t i = 0; i < graph->outputs().size();) {
    auto node = graph->outputs()[i]->node();
    // unpack output
    switch (node->kind()) {
      case prim::TupleConstruct: {
        std::vector<Node *> tuples;
        std::vector<Value *> values;
        flattenTuple(node, tuples, values);
        for (auto realOutput : values) {
          graph->registerOutput(realOutput);
        }
        graph->eraseOutput(i);
        for (auto tuple : tuples) {
          if (!tuple->hasUses()) {
            tuple->destroy();
          }
        }
        break;
      }
      case prim::DictConstruct: {
        // DictConstruct inputs are interleaved key/value pairs; register every value
        for (size_t j = 1; j < node->inputs().size(); j += 2) {
          graph->registerOutput(node->input(j));
        }
        graph->eraseOutput(i);
        node->destroy();
        break;
      }
      case prim::ListConstruct: {
        for (size_t j = 0; j < node->inputs().size(); j++) {
          graph->registerOutput(node->input(j));
        }
        graph->eraseOutput(i);
        node->destroy();
        break;
      }
      default: {
        MS_LOG(INFO) << "skip " << mindspore::lite::PytorchNodeParser::GetTorchNodeType(node);
        i++;  // non-container output: keep it and move on
        break;
      }
    }
  }
}

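/*
Fuse prim::ListUnpack into its producer when ops like aten::split or aten::chunk
produce a list that is consumed only by the unpack, like below:
  %list : Tensor[] = aten::chunk(%x, %chunks, %dim)
  %a : Tensor, %b : Tensor = prim::ListUnpack(%list)
becomes a single multi-output node:
  %a : Tensor, %b : Tensor = aten::chunk(%x, %chunks, %dim)
*/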
void FuseListUnpack(Block *block) {
  static const std::set<NodeKind> fusekind = {
    aten::split,
    aten::split_with_sizes,
    aten::unsafe_split_with_sizes,
    aten::unbind,
    aten::chunk,
    aten::unsafe_chunk,
    aten::where,
  };
  // Collect the candidate producers first: destroying ListUnpack nodes while
  // walking the node list could otherwise invalidate the iterator.
  std::vector<Node *> candidates;
  for (auto *node : block->nodes()) {
    for (Block *sub_block : node->blocks()) {
      FuseListUnpack(sub_block);
    }
    if (fusekind.count(node->kind()) && node->outputs().size() == 1 && node->output()->uses().size() == 1) {
      candidates.push_back(node);
    }
  }
  for (auto *node : candidates) {
    const auto listunpack = node->output()->uses()[0].user;
    if (listunpack->kind() != prim::ListUnpack) {
      continue;
    }
    // give the producer one output per unpacked element, drop its list output,
    // then redirect all uses of the ListUnpack to the producer
    for (size_t i = 0; i < listunpack->outputs().size(); ++i) {
      auto new_output = node->addOutput();
      new_output->copyMetadata(listunpack->output(i));
    }
    listunpack->removeAllInputs();
    node->eraseOutput(0);
    listunpack->replaceAllUsesWith(node);
    listunpack->destroy();
  }
}

/*
Remove every prim::ListConstruct op that has only one input and is not used by
aten::cat, like below:
  %116 : Tensor?[] = prim::ListConstruct(%115)
  %alpha0.1 : Tensor = aten::index_put_(%alpha.1, %116, %x.1, %16)
A ListConstruct used by aten::cat is preserved, like below:
  %features.2 : Tensor[] = prim::ListConstruct(%input3.4)
  %concated_features.380 : Tensor = aten::cat(%features.2, %5)
Note: run this pass after removeListAppend.
*/
void RemoveListConstructOps(Block *block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; ++it) {
    if (it->kind() == prim::ListConstruct && it->inputs().size() == 1) {
      bool remove = true;
      for (auto use : it->output()->uses()) {
        if (use.user->kind() == aten::cat) {
          remove = false;
          break;
        }
      }
      if (remove) {
        // forward the single element to all users, then delete the node;
        // destroyCurrent() keeps the iterator valid
        it->output()->replaceAllUsesWith(it->input(0));
        it->removeInput(0);
        it.destroyCurrent();
      }
    }
  }
}

// flatten a tuple graph input into one input per element and remove its prim::TupleUnpack
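// For example:
//   graph(%self, %args : (Tensor, Tensor)):
//     %x : Tensor, %y : Tensor = prim::TupleUnpack(%args)
// becomes
//   graph(%self, %x : Tensor, %y : Tensor):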
bool FlattenInputsTuple(Graph *graph) {
  for (size_t i = 0; i < graph->inputs().size();) {
    auto input_value = graph->inputs()[i];
    auto tuple = input_value->type()->cast<at::TupleType>();
    if (!tuple) {
      i++;
      continue;
    }

    auto use_list = input_value->uses();
    if (use_list.size() != 1) {
      MS_LOG(ERROR) << "this pass only supports a tuple input with exactly one user!";
      return false;
    }
    auto tuple_unpack = use_list[0].user;
    auto node_type = mindspore::lite::PytorchNodeParser::GetTorchNodeType(tuple_unpack);
    if (node_type != "TupleUnpack") {
      MS_LOG(ERROR) << "unsupported user node type for tuple input: " << node_type;
      return false;
    }

    auto elements = tuple->elements();
    size_t idx = 0;
    for (auto &element : elements) {
      // append a fresh graph input for this tuple element and redirect all
      // users of the corresponding TupleUnpack output to it
      auto new_input = graph->addInput(tuple_unpack->output(idx)->debugName());
      new_input->setType(element);

      auto tuple_item = tuple_unpack->output(idx);
      auto item_use_list = tuple_item->uses();
      for (const auto &use : item_use_list) {
        use.user->replaceInputWith(tuple_item, new_input);
      }
      idx++;
    }
    tuple_unpack->destroy();
    graph->eraseInput(i);  // the input that shifts into slot i is examined next
  }
  return true;
}

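/*
Entry point for the converter: freeze the module, inline and normalize the
forward graph, then run the passes above so the exported graph has flat
inputs and outputs.
*/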
std::shared_ptr<Graph> TorchGraphTransform(Module *module) {
  module->eval();                                 // freezing requires the module to be in eval mode
  auto mod = torch::jit::freeze_module(*module);  // freeze module
  auto torch_graph = mod.get_method("forward").graph();
  if (torch_graph == nullptr) {
    return nullptr;
  }
  // inline submodule and function calls so the whole model is a single graph
  torch::jit::Inline(*torch_graph);
  torch::jit::NormalizeOps(torch_graph);

  RemoveListConstructOps(torch_graph->block());
  if (!FlattenInputsTuple(torch_graph.get())) {
    return nullptr;
  }
  FuseListUnpack(torch_graph->block());

  OutputsUnpack(torch_graph.get());
  return torch_graph;
}
} // namespace jit
} // namespace torch