• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_SLICE_TO_TUPLE_H
17 #define MINDSPORE_SLICE_TO_TUPLE_H
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <map>
23 
24 #include "frontend/optimizer/optimizer_caller.h"
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "frontend/optimizer/anf_visitor.h"
29 #include "frontend/operator/ops.h"
30 #include "frontend/optimizer/irpass.h"
31 #include "frontend/optimizer/optimizer.h"
32 #include "include/common/utils/utils.h"
33 
34 namespace mindspore {
35 namespace opt {
36 namespace irpass {
37 // {prim::kPrimSliceGetItem, S, "start"} => {prim::kPrimTupleGetItem, S, 0}
38 // {prim::kPrimSliceGetItem, S, "stop"} => {prim::kPrimTupleGetItem, S, 1}
39 // {prim::kPrimSliceGetItem, S, "step"} => {prim::kPrimTupleGetItem, S, 2}
40 // {prim::kPrimMakeSlice, {X, Y, Z}} => {prim::kPrimMakeTuple, {X, Y, Z}}
41 class SliceToTuple : public AnfVisitor {
42  public:
operator()43   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
44     if (IsPrimitiveCNode(node, prim::kPrimMakeSlice)) {
45       auto make_slice = node->cast<CNodePtr>();
46       auto make_tuple_inputs = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimMakeTuple)};
47       std::copy(make_slice->inputs().cbegin() + 1, make_slice->inputs().cend(), std::back_inserter(make_tuple_inputs));
48       return make_slice->func_graph()->NewCNode(make_tuple_inputs);
49     }
50     if (IsPrimitiveCNode(node, prim::kPrimSliceGetItem)) {
51       auto slice_getitem = node->cast<CNodePtr>();
52       auto slice_getitem_slice_input = slice_getitem->input(1);
53       auto slice_getitem_item_input = slice_getitem->input(2);
54       if (!IsValueNode<StringImm>(slice_getitem_item_input)) {
55         return nullptr;
56       }
57       auto vnode = slice_getitem_item_input->cast<ValueNodePtr>();
58       auto slice_attr = GetValue<std::string>(vnode->value());
59       static const std::map<std::string, size_t> kSliceAttrToStaticIndex = {
60         {kSliceStart, 0}, {kSliceStop, 1}, {kSliceStep, 2}};
61       auto iter = kSliceAttrToStaticIndex.find(slice_attr);
62       if (iter == kSliceAttrToStaticIndex.end()) {
63         MS_EXCEPTION(ValueError) << "The slice must be [start, stop, step], but got " << slice_attr;
64       }
65       auto getitem_tuple_inputs =
66         std::vector<AnfNodePtr>{NewValueNode(prim::kPrimTupleGetItem), slice_getitem_slice_input,
67                                 NewValueNode(MakeValue<int64_t>(iter->second))};
68       return slice_getitem->func_graph()->NewCNode(getitem_tuple_inputs);
69     }
70     return nullptr;
71   }
72 };
73 }  // namespace irpass
74 }  // namespace opt
75 }  // namespace mindspore
76 #endif  // MINDSPORE_SLICE_TO_TUPLE_H
77