• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "torch/csrc/jit/passes/freeze_module.h"
#include "torch/csrc/jit/passes/inliner.h"
#include "torch/csrc/jit/passes/normalize_ops.h"
#include "torch/csrc/jit/passes/remove_mutation.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/lite/tools/converter/parser/pytorch/pytorch_node_parser.h"
27 
28 namespace torch {
29 namespace jit {
OutputsUnpack(Graph * graph)30 void OutputsUnpack(Graph *graph) {
31   std::function<void(Node * tuple, std::vector<Node *> &, std::vector<Value *> &)> flattenTuple =
32     [&flattenTuple](Node *tuple, std::vector<Node *> &tuples, std::vector<Value *> &values) -> void {
33     tuples.push_back(tuple);
34     for (auto input : tuple->inputs()) {
35       auto node = input->node();
36       if (node->kind() == prim::TupleConstruct) {
37         flattenTuple(node, tuples, values);
38       } else {
39         values.push_back(input);
40       }
41     }
42   };
43   for (size_t i = 0; i < graph->outputs().size(); i++) {
44     auto node = graph->outputs()[i]->node();
45     // unpack output
46     switch (node->kind()) {
47       case prim::TupleConstruct: {
48         std::vector<Node *> tuples;
49         std::vector<Value *> values;
50         flattenTuple(node, tuples, values);
51         for (auto realOutput : values) {
52           graph->registerOutput(realOutput);
53         }
54         graph->eraseOutput(i);
55         for (auto tuple : tuples) {
56           if (!tuple->hasUses()) {
57             tuple->destroy();
58           }
59         }
60         break;
61       }
62       case prim::DictConstruct: {
63         graph->registerOutput(node->input(1));
64         graph->eraseOutput(i);
65         node->destroy();
66         break;
67       }
68       case prim::ListConstruct: {
69         for (size_t j = 0; i < node->inputs().size(); j++) {
70           graph->registerOutput(node->input(j));
71         }
72         graph->eraseOutput(i);
73         node->destroy();
74         break;
75       }
76       default: {
77         MS_LOG(INFO) << "skip " << mindspore::lite::PytorchNodeParser::GetTorchNodeType(node);
78         break;
79       }
80     }
81   }
82 }
83 
FuseListUnpack(Block * block)84 void FuseListUnpack(Block *block) {
85   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
86     auto *node = *it;
87     it++;
88 
89     for (Block *sub_block : node->blocks()) {
90       FuseListUnpack(sub_block);
91     }
92     std::set<NodeKind> fusekind = {
93       aten::split,
94       aten::split_with_sizes,
95       aten::split_with_sizes,
96       aten::unsafe_split_with_sizes,
97       aten::unbind,
98       aten::chunk,
99       aten::unsafe_chunk,
100       aten::where,
101     };
102     if (fusekind.count(it->kind()) && it->outputs().size() == 1 && it->output()->uses().size() == 1) {
103       const auto listunpack = it->output()->uses()[0].user;
104       if (listunpack->kind() == prim::ListUnpack) {
105         for (size_t i = 0; i < listunpack->outputs().size(); ++i) {
106           auto new_output = it->addOutput();
107           new_output->copyMetadata(listunpack->output(i));
108         }
109         listunpack->removeAllInputs();
110         it->eraseOutput(0);
111         listunpack->replaceAllUsesWith(*it);
112         listunpack->destroy();
113       }
114     }
115   }
116 }
117 
118 /*
119    Remove all ListConstruct op with only one input and not used by aten::cat, like below:
120         %116 : Tensor?[] = prim::ListConstruct(%115)
121         %alpha0.1 : Tensor = aten::index_put_(%alpha.1, %116, %x.1, %16)
122    ListConstruct used by aten::cat will be reserved like below:
123         %features.2 : Tensor[] = prim::ListConstruct(%input3.4)
124         %concated_features.380 : Tensor = aten::cat(%features.2, %5)
125    Attention: Running this pass after removeListAppend
126  */
RemoveListConstructOps(Block * block)127 void RemoveListConstructOps(Block *block) {
128   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; ++it) {
129     if (it->kind() == prim::ListConstruct && it->inputs().size() == 1) {
130       bool remove = true;
131       for (auto use : it->output()->uses()) {
132         if (use.user->kind() == aten::cat) {
133           remove = false;
134           break;
135         }
136       }
137       if (remove) {
138         it->output()->replaceAllUsesWith(it->input(0));
139         it->removeInput(0);
140         it.destroyCurrent();
141       }
142     }
143   }
144 }
145 
146 // flatten tuple input and remove tuple unpack
FlattenInputsTuple(Graph * graph)147 bool FlattenInputsTuple(Graph *graph) {
148   for (size_t i = 0; i < graph->inputs().size(); i++) {
149     auto input_value = graph->inputs()[i];
150     auto tuple = input_value->type()->cast<at::TupleType>();
151     if (!tuple) {
152       continue;
153     }
154 
155     auto use_list = input_value->uses();
156     if (use_list.size() != 1) {
157       MS_LOG(ERROR) << "current pass only supports tuple input has only one user!";
158       return false;
159     }
160     auto tuple_unpack = use_list[0].user;
161     auto node_type = mindspore::lite::PytorchNodeParser::GetTorchNodeType(tuple_unpack);
162     if (node_type != "TupleUnpack") {
163       MS_LOG(ERROR) << "unsupported node user type of tuple: " << node_type;
164       return false;
165     }
166 
167     auto elements = tuple->elements();
168     size_t idx = 0;
169     for (auto &element : elements) {
170       auto new_input = graph->addInput(tuple_unpack->output(idx)->debugName());
171       new_input->setType(element);
172 
173       auto tuple_item = tuple_unpack->output(idx);
174       auto item_use_list = tuple_item->uses();
175       for (const auto &use : item_use_list) {
176         use.user->replaceInputWith(tuple_item, new_input);
177       }
178       idx++;
179     }
180     tuple_unpack->destroy();
181     graph->eraseInput(i);
182   }
183   return true;
184 }
185 
TorchGraphTransform(Module * module)186 std::shared_ptr<Graph> TorchGraphTransform(Module *module) {
187   module->eval();                                 // eval to expand function call
188   auto mod = torch::jit::freeze_module(*module);  // freeze module
189   auto torch_graph = mod.get_method("forward").graph();
190   if (torch_graph == nullptr) {
191     return nullptr;
192   }
193   // parse submodules in graph
194   torch::jit::Inline(*torch_graph);
195   torch::jit::NormalizeOps(torch_graph);
196 
197   RemoveListConstructOps(torch_graph->block());
198   FlattenInputsTuple(torch_graph.get());
199   FuseListUnpack(torch_graph->block());
200 
201   OutputsUnpack(torch_graph.get());
202   return torch_graph;
203 }
204 }  // namespace jit
205 }  // namespace torch
206