1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/passes/dead_code_elimination.h>
3 #include <torch/csrc/jit/passes/erase_number_types.h>
4 #include <torch/csrc/jit/passes/onnx.h>
5 #include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
6 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
7
8 #include <ATen/ScalarOps.h>
9
10 #include <iostream>
11
12 // EDITING THIS FILE? READ THIS FIRST!
13 // see Note [Edit Pattern Conversion] in pattern_conversion.h
14
15 namespace torch {
16 namespace jit {
17
18 // Converting inplace index_put to ONNX
19 namespace {
20
CreateSizeOfDim(Value * input,int64_t dim,Node * insertBefore)21 Value* CreateSizeOfDim(Value* input, int64_t dim, Node* insertBefore) {
22 auto graph = input->owningGraph();
23 WithInsertPoint guard(insertBefore);
24 auto size = graph->insert(aten::size, {input, dim});
25 return size;
26 }
27
ConvertSelectToIndex(Value * index,Node * insertBefore)28 Value* ConvertSelectToIndex(Value* index, Node* insertBefore) {
29 // Create index tensor based on index input of aten::select node.
30 auto graph = insertBefore->owningGraph();
31 WithInsertPoint guard(insertBefore);
32 return graph->insert(aten::unsqueeze, {index, 0});
33 }
34
ConvertSliceToIndex(Node * slice,Value * size,Node * insertBefore)35 Value* ConvertSliceToIndex(Node* slice, Value* size, Node* insertBefore) {
36 // Create index tensor based on aten::slice node.
37 auto graph = slice->owningGraph();
38 WithInsertPoint guard(insertBefore);
39 TORCH_INTERNAL_ASSERT((slice->inputs()).size() == 5);
40 auto start = slice->inputs()[2];
41 auto end = slice->inputs()[3];
42 auto step = slice->inputs()[4];
43 auto index =
44 graph->insert(aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
45 auto sliced_index_n = graph->create(
46 aten::slice,
47 {index,
48 graph->insertConstant(
49 scalar_to_tensor(at::Scalar(0)), std::nullopt, slice->scope()),
50 start,
51 end,
52 step});
53
54 sliced_index_n->copyMetadata(insertBefore);
55 auto sliced_index = sliced_index_n->insertBefore(insertBefore)->output();
56 return sliced_index;
57 }
58
59 struct ConvertedIndex {
ConvertedIndextorch::jit::__anon3f1df3b40111::ConvertedIndex60 ConvertedIndex(Value* index, c10::Symbol orig_node_kind)
61 : index(index), orig_node_kind(orig_node_kind) {}
62
63 Value* index = nullptr;
64 c10::Symbol orig_node_kind;
65 };
66
MergeSliceAndSelectToIndices(Graph * graph,Node * index_put_node,const std::vector<Node * > & slice_and_select_nodes,Value * orig_data,const py::dict & env)67 std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
68 Graph* graph,
69 Node* index_put_node,
70 const std::vector<Node*>& slice_and_select_nodes,
71 Value* orig_data,
72 const py::dict& env) {
73 std::unordered_map<int64_t, ConvertedIndex> dim_index_map;
74
75 // Loop over fetched slice and select nodes and convert them to index tensors.
76 // keep track of which dimension the current slice/select node is applying to.
77 int64_t cur_dim = 0;
78 int64_t dim_offset = 0;
79 const auto orig_tensor_indices = index_put_node->input(1)->node()->inputs();
80 for (auto it = slice_and_select_nodes.rbegin();
81 it != slice_and_select_nodes.rend();
82 ++it) {
83 auto node = *it;
84 // select does not keep dims,
85 // this creates offset for latter slice and select nodes.
86 // NOTE: Cannot rely on get(attr::dim), because op no longer match schema.
87 int64_t dim = node->inputs().at(1)->node()->t(attr::value).item().toLong();
88
89 if (dim < 0) {
90 // auto input_type = env.at(orig_data)->type()->expect<TensorType>();
91 auto py_value = env[py::cast(orig_data)];
92 Value* value = py_value.cast<Value*>();
93 auto input_type = value->type()->expect<TensorType>();
94 if (input_type->dim().has_value()) {
95 auto rank = static_cast<int64_t>(input_type->dim().value());
96 // Rank of original tensor to index on.
97 // Minus the offset created by select operators.
98 dim = dim + rank - dim_offset;
99 } else {
100 std::cerr
101 << "Error: Cannot export ellipsis indexing for input "
102 << "of unknown rank. Check https://pytorch.org/docs/stable/onnx.html#indexing"
103 << "for details.";
104 }
105 }
106 dim = dim + dim_offset;
107 while (cur_dim < dim) {
108 // Handle skipped dims, these are created from ..., or tensor indices
109 // E.g.: x[torch.tensor([1, 0]), ..., 0] = update, where x has rank 3.
110 // Both torch.tensor([1, 0]) and ... are skipped, we only observe
111 // aten::select node with dim == 2. Tensor indices will be handled later.
112 // Ellipsis(...) are treated as a complete slice over the axes, thus we
113 // create index tensors here accordingly.
114 if (cur_dim - dim_offset >= (int64_t)orig_tensor_indices.size() ||
115 index_put_node->input(1)
116 ->node()
117 ->input(cur_dim - dim_offset)
118 ->node()
119 ->mustBeNone()) {
120 auto size = CreateSizeOfDim(orig_data, cur_dim, index_put_node);
121 WithInsertPoint guard(index_put_node);
122 auto index_tensor = graph->insert(
123 aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
124 dim_index_map.emplace(
125 std::piecewise_construct,
126 std::forward_as_tuple(cur_dim),
127 std::forward_as_tuple(index_tensor, aten::slice));
128 } else if (cur_dim - dim_offset < (int64_t)orig_tensor_indices.size()) {
129 dim_index_map.emplace(
130 std::piecewise_construct,
131 std::forward_as_tuple(cur_dim),
132 std::forward_as_tuple(
133 orig_tensor_indices[cur_dim - dim_offset], aten::index));
134 }
135 cur_dim++;
136 }
137
138 TORCH_INTERNAL_ASSERT(cur_dim == dim);
139 if (node->kind() == aten::slice) {
140 auto size = CreateSizeOfDim(orig_data, dim, index_put_node);
141 auto index_tensor = ConvertSliceToIndex(node, size, index_put_node);
142 dim_index_map.emplace(
143 std::piecewise_construct,
144 std::forward_as_tuple(dim),
145 std::forward_as_tuple(index_tensor, aten::slice));
146 } else if (node->kind() == aten::select) {
147 auto index_tensor = ConvertSelectToIndex(node->input(2), index_put_node);
148 dim_index_map.emplace(
149 std::piecewise_construct,
150 std::forward_as_tuple(dim),
151 std::forward_as_tuple(index_tensor, aten::select));
152 dim_offset++;
153 } else {
154 TORCH_CHECK(
155 false,
156 node->kind().toDisplayString(),
157 " Expected aten::slice or aten::select.");
158 }
159
160 cur_dim++;
161 }
162
163 while (cur_dim - dim_offset < (int64_t)orig_tensor_indices.size()) {
164 dim_index_map.emplace(
165 std::piecewise_construct,
166 std::forward_as_tuple(cur_dim),
167 std::forward_as_tuple(
168 orig_tensor_indices[cur_dim - dim_offset], aten::index));
169 cur_dim++;
170 }
171
172 // Each dimension should have its associated index tensor.
173 TORCH_INTERNAL_ASSERT((int64_t)dim_index_map.size() == cur_dim);
174 return dim_index_map;
175 }
176
177 // Convert slice/select operators to tensor indices.
178 // Reshape the tensor indices according to their axis.
179 // E.g. x[1:3, 0, ind1, ind2] = y
180 // slice index shape: [2, 1, 1 ]
181 // select index shape: [ 1, 1 ]
182 // ind1 shape: [ _ ]
183 // ind2 shape: [ _ ]
184 // where _ is the original size of ind1 and ind2.
185 // ind1 and ind2 are both 1-d tensors since currently we only supports 1-d
186 // tensor indices.
ReshapeToAdvancedIndexingFormat(Graph * graph,Node * index_put_node,std::unordered_map<int64_t,ConvertedIndex> & dim_index_map)187 std::vector<Value*> ReshapeToAdvancedIndexingFormat(
188 Graph* graph,
189 Node* index_put_node,
190 std::unordered_map<int64_t, ConvertedIndex>& dim_index_map) {
191 std::vector<Value*> indices;
192
193 size_t min_index_dim = dim_index_map.size();
194 size_t max_index_dim = 0;
195 size_t tensor_ind_count = 0;
196 for (const auto i : c10::irange(dim_index_map.size())) {
197 auto index_i = dim_index_map.find(i);
198 TORCH_INTERNAL_ASSERT(index_i != dim_index_map.end());
199 if (index_i->second.orig_node_kind == aten::index) {
200 if (i < min_index_dim)
201 min_index_dim = i;
202 if (i > max_index_dim)
203 max_index_dim = i;
204 tensor_ind_count++;
205 }
206 }
207
208 if (((max_index_dim - min_index_dim + 1) != tensor_ind_count) &&
209 tensor_ind_count != 0) {
210 TORCH_CHECK(
211 false,
212 "Only consecutive 1-d tensor indices are supported in exporting aten::index_put to ONNX.",
213 "Check https://pytorch.org/docs/stable/onnx.html#indexing for details");
214 }
215
216 size_t tensor_ind_offset = tensor_ind_count == 0 ? 0 : tensor_ind_count - 1;
217 WithInsertPoint guard(index_put_node);
218 for (const auto i : c10::irange(dim_index_map.size())) {
219 size_t ind_size = 0;
220 auto index_i = dim_index_map.find(i);
221 TORCH_INTERNAL_ASSERT(index_i != dim_index_map.end());
222 Value* index = index_i->second.index;
223 switch (index_i->second.orig_node_kind) {
224 case aten::select:
225 case aten::slice: {
226 if (i < min_index_dim) {
227 ind_size = dim_index_map.size() - tensor_ind_offset - i;
228 } else {
229 ind_size = dim_index_map.size() - i;
230 }
231 break;
232 }
233
234 case aten::index: {
235 ind_size = dim_index_map.size() - tensor_ind_offset - min_index_dim;
236 break;
237 }
238 default:
239 TORCH_CHECK(
240 false, "Unexpected node kind ", index_i->second.orig_node_kind);
241 }
242
243 if (ind_size != 1) {
244 std::vector<int64_t> view_shape(ind_size, 1);
245 view_shape[0] = -1;
246 auto unsqueezed_index = graph->insert(aten::view, {index, view_shape});
247 indices.emplace_back(unsqueezed_index);
248 } else {
249 indices.emplace_back(index);
250 }
251 }
252
253 return indices;
254 }
255
256 // Trace back all the slice & select nodes associated with the index_put node,
257 // and convert them to associated indices.
258 // E.g. The IR for x[1:3, 0] = update
259 // ...
260 // %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
261 // ...
262 // %11 : Float(2) = aten::select(%8, %9, %10)
263 // ...
264 // %13 : Tensor?[] = prim::ListConstruct()
265 // ...
266 // %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
267 // The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
268 // = prim::ListConstruct()).
269 // ...
270 // # Below constructs index from slice node.
271 // %23 : Long() = aten::size(%0, %4)
272 // %28 : Tensor = aten::arange(%23, %24, %25, %26, %27)
273 // %33 : Tensor = aten::slice(%28, %4, %5, %6, %7)
274 // %39 : int[] = prim::Constant[value=[-1, 1]]()
275 // %40 : Tensor = aten::view(%33, %39)
276 // ...
277 // # Below constructs index from select node.
278 // %36 : int = prim::Constant[value=0]()
279 // %37 : Tensor = aten::unsqueeze(%10, %36)
280 // %42 : int[] = prim::Constant[value=[-1]]()
281 // %43 : Tensor = aten::view(%37, %42)
282 // ...
283 // # Adding the above two indices to index_put
284 // %44 : Tensor?[] = prim::ListConstruct(%40, %43)
285 // %45 : Float(2, 5) = aten::index_put(%0, %44, %14, %15)
ConvertIndexPutToONNX(Block * new_block,Node * old_node,py::dict & env,py::set & values_in_env)286 std::vector<Value*> ConvertIndexPutToONNX(
287 Block* new_block,
288 Node* old_node,
289 py::dict& env,
290 py::set& values_in_env) {
291 if (old_node->kind() != Symbol::fromQualString("onnx::Placeholder") ||
292 (old_node->s(attr::name) != "index_put" &&
293 old_node->s(attr::name) != "index_put_")) {
294 return {};
295 }
296
297 TORCH_INTERNAL_ASSERT(old_node->blocks().size() == 1);
298 auto old_graph = old_node->owningGraph();
299 auto subblock = old_node->blocks()[0];
300 auto index_put_node = subblock->nodes().back()->prev();
301
302 // Find slice and select operators that are associated with this index
303 // operator. E.g. x[1:3, 0] = y will generate one slice operator(1:3) and one
304 // select operator(0).
305 std::vector<Node*> slice_and_select_nodes =
306 IndexingPatternFinder::FetchSliceAndSelect(index_put_node);
307 Node* last_node = !slice_and_select_nodes.empty()
308 ? slice_and_select_nodes.back()
309 : index_put_node;
310 // Update inner block input originates from outside.
311 last_node->replaceInput(0, old_node->input(0));
312 Value* orig_data = last_node->input(0);
313
314 // Convert slice and select operators to indices.
315 std::unordered_map<int64_t, ConvertedIndex> dim_index_map =
316 MergeSliceAndSelectToIndices(
317 old_graph, index_put_node, slice_and_select_nodes, orig_data, env);
318
319 // Reshape indices to advanced indexing format.
320 std::vector<Value*> indices =
321 ReshapeToAdvancedIndexingFormat(old_graph, index_put_node, dim_index_map);
322
323 // Create new index_put node with converted indices.
324 const auto list_indices =
325 old_graph->createList(OptionalType::ofTensor(), indices)
326 ->insertBefore(index_put_node)
327 ->output();
328 auto new_index_put_node = old_graph->create(
329 aten::index_put,
330 {orig_data,
331 list_indices,
332 index_put_node->input(2),
333 index_put_node->input(3)});
334 new_index_put_node->insertBefore(index_put_node);
335 new_index_put_node->copyMetadata(index_put_node);
336 auto new_index_put = new_index_put_node->output();
337 new_index_put->copyMetadata(index_put_node->output());
338 index_put_node->output()->replaceAllUsesWith(new_index_put);
339
340 // Convert aten type to onnx type.
341 EraseNumberTypesOnBlock(subblock);
342 EliminateDeadCode(
343 subblock,
344 true,
345 DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
346
347 // Convert all the new aten nodes that were just created to onnx.
348 // New onnx nodes are appended at the end of new_block.
349 for (auto at_n : subblock->nodes()) {
350 if (at_n == subblock->param_node() || at_n == subblock->return_node()) {
351 continue;
352 }
353
354 NodeToONNX(
355 at_n,
356 new_block,
357 torch::onnx::OperatorExportTypes::ONNX,
358 env,
359 values_in_env);
360 }
361
362 // Find onnx outputs corresponding to the aten outputs of index_put.
363 std::vector<Value*> outs;
364 for (auto o : subblock->return_node()->inputs()) {
365 auto py_value = env[py::cast(o)];
366 Value* value = py_value.cast<Value*>();
367 outs.emplace_back(value);
368 }
369 return outs;
370 }
371
372 } // namespace
373
ConvertPatternFromSubblock(Block * new_block,Node * old_node,py::dict & env,py::set & values_in_env)374 std::vector<Value*> ConvertPatternFromSubblock(
375 Block* new_block,
376 Node* old_node,
377 py::dict& env,
378 py::set& values_in_env) {
379 std::vector<Value*> res;
380
381 if (old_node->kind() != Symbol::fromQualString("onnx::Placeholder")) {
382 return res;
383 }
384
385 // The pattern conversion code should not alter nodes outside the Placeholder
386 // subblock.
387 auto op_name = old_node->s(attr::name);
388 if (op_name == "index_put" || op_name == "index_put_") {
389 res = ConvertIndexPutToONNX(new_block, old_node, env, values_in_env);
390 }
391
392 return res;
393 }
394
395 } // namespace jit
396 } // namespace torch
397