1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "mapper/slice_mapper.h"
18 #include <memory>
19 #include <utility>
20 #include <algorithm>
21 #include <vector>
22 #include "common/op_attr.h"
23 #include "common/anf_util.h"
24 #include "common/op_enum.h"
25 #include "common/check_base.h"
26 #include "ops/auto_generate/gen_lite_ops.h"
27 #include "op/slice_operator.h"
28
29 namespace mindspore {
30 namespace dpico {
Map(const api::CNodePtr & cnode,std::vector<BaseOperatorPtr> * base_operators,const api::PrimitivePtr & prim,const api::CNodePtrList & output_cnodes)31 STATUS SliceMapper::Map(const api::CNodePtr &cnode, std::vector<BaseOperatorPtr> *base_operators,
32 const api::PrimitivePtr &prim, const api::CNodePtrList &output_cnodes) {
33 if (base_operators == nullptr) {
34 MS_LOG(ERROR) << "base_operators is nullptr.";
35 return RET_ERROR;
36 }
37 auto split_prim = api::utils::cast<api::SharedPtr<ops::Split>>(prim);
38 MS_ASSERT(split_prim != nullptr);
39
40 auto slice_operator = std::make_unique<mapper::SliceOperator>();
41 MS_CHECK_TRUE_MSG(slice_operator != nullptr, RET_ERROR, "slice_operator is nullptr.");
42
43 if (SetCommonAttr(cnode, slice_operator.get(), output_cnodes) != RET_OK) {
44 MS_LOG(ERROR) << "set common attr failed. " << cnode->fullname_with_scope();
45 return RET_ERROR;
46 }
47
48 slice_operator->SetOpType(mapper::OpType::SLICE);
49 ShapeVector shape;
50 if (GetInputShapeFromCNode(cnode, kInputIndex1, &shape) != RET_OK) {
51 MS_LOG(ERROR) << "fetch input shape failed.";
52 return RET_ERROR;
53 }
54 if (std::any_of(shape.begin(), shape.end(),
55 [](int64_t dim) { return dim <= 0 || dim > static_cast<int64_t>(UINT32_MAX); })) {
56 MS_LOG(ERROR) << "shape is invalid, which is not larger than 0 and less than uint32_max";
57 return RET_ERROR;
58 }
59 MS_ASSERT(shape.size() <= kDims4);
60 if (split_prim->GetAttr(ops::kAxis) == nullptr) {
61 MS_LOG(ERROR) << "axis attr is nullptr, please check split_checker.";
62 return RET_ERROR;
63 }
64 auto split_axis = split_prim->get_axis();
65 split_axis = split_axis < 0 ? split_axis + static_cast<int64_t>(shape.size()) : split_axis;
66 if (split_axis > static_cast<int64_t>(kDims4)) {
67 MS_LOG(ERROR) << "split axis is invalid.";
68 return RET_ERROR;
69 }
70 slice_operator->SetAxis(static_cast<int32_t>(split_axis));
71 if (split_prim->GetAttr(ops::kSizeSplits) != nullptr) {
72 auto sizes = api::GetValue<std::vector<int64_t>>(split_prim->GetAttr("size_splits"));
73 if (sizes.empty()) {
74 MS_LOG(ERROR) << "sizes shouldn't be empty." << cnode->fullname_with_scope();
75 return RET_ERROR;
76 }
77 if (std::any_of(sizes.begin(), sizes.end(), [](int64_t size) { return size > static_cast<int64_t>(UINT32_MAX); })) {
78 MS_LOG(ERROR) << "split sizes is invalid, which is not larger than 0 and less than uint32_max";
79 return RET_ERROR;
80 }
81 std::vector<uint32_t> sizes_u;
82 (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_u),
83 [](int64_t size) { return static_cast<uint32_t>(size); });
84 uint32_t slice_point_cnt = 0;
85 for (size_t i = 0; i < sizes_u.size() - 1; i++) {
86 if (sizes_u.at(i) >= (static_cast<uint32_t>(shape[split_axis]) - slice_point_cnt)) {
87 MS_LOG(ERROR) << "split sizes is invalid, which is larger than the related dim.";
88 return RET_ERROR;
89 }
90 slice_operator->AddSlicePoint(sizes_u.at(i) + slice_point_cnt);
91 slice_point_cnt += sizes_u.at(i);
92 }
93 }
94
95 if (slice_operator->GetSlicePointVec().empty()) {
96 if (split_prim->GetAttr(ops::kOutputNum) == nullptr) {
97 MS_LOG(ERROR) << "cannot determine split points.";
98 return RET_ERROR;
99 }
100 auto output_num = api::GetValue<int64_t>(split_prim->GetAttr(ops::kOutputNum));
101 MS_CHECK_TRUE_MSG(output_num != 0, RET_ERROR, "output_num is 0.");
102 if (shape[split_axis] % output_num != 0) {
103 MS_LOG(ERROR) << "output_num is 0 or split op is invalid, which input shape cannot be splited.";
104 return RET_ERROR;
105 }
106 uint32_t size_of_each_out = static_cast<uint32_t>(shape[split_axis] / output_num);
107 for (uint32_t i = 1; i < static_cast<uint32_t>(output_num); ++i) {
108 slice_operator->AddSlicePoint(i * size_of_each_out);
109 }
110 }
111 base_operators->push_back(std::move(slice_operator));
112 return RET_OK;
113 }
114 REG_MAPPER(Split, SliceMapper)
115 } // namespace dpico
116 } // namespace mindspore
117