• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "mapper/slice_mapper.h"
18 #include <memory>
19 #include <utility>
20 #include <algorithm>
21 #include <vector>
22 #include "common/op_attr.h"
23 #include "common/anf_util.h"
24 #include "common/op_enum.h"
25 #include "common/check_base.h"
26 #include "ops/auto_generate/gen_lite_ops.h"
27 #include "op/slice_operator.h"
28 
29 namespace mindspore {
30 namespace dpico {
Map(const api::CNodePtr & cnode,std::vector<BaseOperatorPtr> * base_operators,const api::PrimitivePtr & prim,const api::CNodePtrList & output_cnodes)31 STATUS SliceMapper::Map(const api::CNodePtr &cnode, std::vector<BaseOperatorPtr> *base_operators,
32                         const api::PrimitivePtr &prim, const api::CNodePtrList &output_cnodes) {
33   if (base_operators == nullptr) {
34     MS_LOG(ERROR) << "base_operators is nullptr.";
35     return RET_ERROR;
36   }
37   auto split_prim = api::utils::cast<api::SharedPtr<ops::Split>>(prim);
38   MS_ASSERT(split_prim != nullptr);
39 
40   auto slice_operator = std::make_unique<mapper::SliceOperator>();
41   MS_CHECK_TRUE_MSG(slice_operator != nullptr, RET_ERROR, "slice_operator is nullptr.");
42 
43   if (SetCommonAttr(cnode, slice_operator.get(), output_cnodes) != RET_OK) {
44     MS_LOG(ERROR) << "set common attr failed. " << cnode->fullname_with_scope();
45     return RET_ERROR;
46   }
47 
48   slice_operator->SetOpType(mapper::OpType::SLICE);
49   ShapeVector shape;
50   if (GetInputShapeFromCNode(cnode, kInputIndex1, &shape) != RET_OK) {
51     MS_LOG(ERROR) << "fetch input shape failed.";
52     return RET_ERROR;
53   }
54   if (std::any_of(shape.begin(), shape.end(),
55                   [](int64_t dim) { return dim <= 0 || dim > static_cast<int64_t>(UINT32_MAX); })) {
56     MS_LOG(ERROR) << "shape is invalid, which is not larger than 0 and less than uint32_max";
57     return RET_ERROR;
58   }
59   MS_ASSERT(shape.size() <= kDims4);
60   if (split_prim->GetAttr(ops::kAxis) == nullptr) {
61     MS_LOG(ERROR) << "axis attr is nullptr, please check split_checker.";
62     return RET_ERROR;
63   }
64   auto split_axis = split_prim->get_axis();
65   split_axis = split_axis < 0 ? split_axis + static_cast<int64_t>(shape.size()) : split_axis;
66   if (split_axis > static_cast<int64_t>(kDims4)) {
67     MS_LOG(ERROR) << "split axis is invalid.";
68     return RET_ERROR;
69   }
70   slice_operator->SetAxis(static_cast<int32_t>(split_axis));
71   if (split_prim->GetAttr(ops::kSizeSplits) != nullptr) {
72     auto sizes = api::GetValue<std::vector<int64_t>>(split_prim->GetAttr("size_splits"));
73     if (sizes.empty()) {
74       MS_LOG(ERROR) << "sizes shouldn't be empty." << cnode->fullname_with_scope();
75       return RET_ERROR;
76     }
77     if (std::any_of(sizes.begin(), sizes.end(), [](int64_t size) { return size > static_cast<int64_t>(UINT32_MAX); })) {
78       MS_LOG(ERROR) << "split sizes is invalid, which is not larger than 0 and less than uint32_max";
79       return RET_ERROR;
80     }
81     std::vector<uint32_t> sizes_u;
82     (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_u),
83                          [](int64_t size) { return static_cast<uint32_t>(size); });
84     uint32_t slice_point_cnt = 0;
85     for (size_t i = 0; i < sizes_u.size() - 1; i++) {
86       if (sizes_u.at(i) >= (static_cast<uint32_t>(shape[split_axis]) - slice_point_cnt)) {
87         MS_LOG(ERROR) << "split sizes is invalid, which is larger than the related dim.";
88         return RET_ERROR;
89       }
90       slice_operator->AddSlicePoint(sizes_u.at(i) + slice_point_cnt);
91       slice_point_cnt += sizes_u.at(i);
92     }
93   }
94 
95   if (slice_operator->GetSlicePointVec().empty()) {
96     if (split_prim->GetAttr(ops::kOutputNum) == nullptr) {
97       MS_LOG(ERROR) << "cannot determine split points.";
98       return RET_ERROR;
99     }
100     auto output_num = api::GetValue<int64_t>(split_prim->GetAttr(ops::kOutputNum));
101     MS_CHECK_TRUE_MSG(output_num != 0, RET_ERROR, "output_num is 0.");
102     if (shape[split_axis] % output_num != 0) {
103       MS_LOG(ERROR) << "output_num is 0 or split op is invalid, which input shape cannot be splited.";
104       return RET_ERROR;
105     }
106     uint32_t size_of_each_out = static_cast<uint32_t>(shape[split_axis] / output_num);
107     for (uint32_t i = 1; i < static_cast<uint32_t>(output_num); ++i) {
108       slice_operator->AddSlicePoint(i * size_of_each_out);
109     }
110   }
111   base_operators->push_back(std::move(slice_operator));
112   return RET_OK;
113 }
114 REG_MAPPER(Split, SliceMapper)
115 }  // namespace dpico
116 }  // namespace mindspore
117