1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_SLICE_TO_TUPLE_H 17 #define MINDSPORE_SLICE_TO_TUPLE_H 18 #include <algorithm> 19 #include <memory> 20 #include <vector> 21 #include <string> 22 #include <map> 23 24 #include "frontend/optimizer/optimizer_caller.h" 25 #include "mindspore/core/ops/structure_ops.h" 26 #include "mindspore/core/ops/sequence_ops.h" 27 #include "mindspore/core/ops/framework_ops.h" 28 #include "frontend/optimizer/anf_visitor.h" 29 #include "frontend/operator/ops.h" 30 #include "frontend/optimizer/irpass.h" 31 #include "frontend/optimizer/optimizer.h" 32 #include "include/common/utils/utils.h" 33 34 namespace mindspore { 35 namespace opt { 36 namespace irpass { 37 // {prim::kPrimSliceGetItem, S, "start"} => {prim::kPrimTupleGetItem, S, 0} 38 // {prim::kPrimSliceGetItem, S, "stop"} => {prim::kPrimTupleGetItem, S, 1} 39 // {prim::kPrimSliceGetItem, S, "step"} => {prim::kPrimTupleGetItem, S, 2} 40 // {prim::kPrimMakeSlice, {X, Y, Z}} => {prim::kPrimMakeTuple, {X, Y, Z}} 41 class SliceToTuple : public AnfVisitor { 42 public: operator()43 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 44 if (IsPrimitiveCNode(node, prim::kPrimMakeSlice)) { 45 auto make_slice = node->cast<CNodePtr>(); 46 auto make_tuple_inputs = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimMakeTuple)}; 47 std::copy(make_slice->inputs().cbegin() + 1, make_slice->inputs().cend(), std::back_inserter(make_tuple_inputs)); 48 return make_slice->func_graph()->NewCNode(make_tuple_inputs); 49 } 50 if (IsPrimitiveCNode(node, prim::kPrimSliceGetItem)) { 51 auto slice_getitem = node->cast<CNodePtr>(); 52 auto slice_getitem_slice_input = slice_getitem->input(1); 53 auto slice_getitem_item_input = slice_getitem->input(2); 54 if (!IsValueNode<StringImm>(slice_getitem_item_input)) { 55 return nullptr; 56 } 57 auto vnode = slice_getitem_item_input->cast<ValueNodePtr>(); 58 auto slice_attr = GetValue<std::string>(vnode->value()); 59 static const std::map<std::string, size_t> kSliceAttrToStaticIndex = { 60 {kSliceStart, 0}, {kSliceStop, 1}, {kSliceStep, 2}}; 61 auto iter = kSliceAttrToStaticIndex.find(slice_attr); 62 if (iter == kSliceAttrToStaticIndex.end()) { 63 MS_EXCEPTION(ValueError) << "The slice must be [start, stop, step], but got " << slice_attr; 64 } 65 auto getitem_tuple_inputs = 66 std::vector<AnfNodePtr>{NewValueNode(prim::kPrimTupleGetItem), slice_getitem_slice_input, 67 NewValueNode(MakeValue<int64_t>(iter->second))}; 68 return slice_getitem->func_graph()->NewCNode(getitem_tuple_inputs); 69 } 70 return nullptr; 71 } 72 }; 73 } // namespace irpass 74 } // namespace opt 75 } // namespace mindspore 76 #endif // MINDSPORE_SLICE_TO_TUPLE_H 77