1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "schema/model_v0_generated.h"
18 #include "src/ops/populate/populate_register.h"
19 #include "nnacl/split_parameter.h"
20 #include "nnacl/op_base.h"
21
22 namespace mindspore {
23 namespace lite {
24 namespace {
DestroySplitParameter(OpParameter * parameter)25 void DestroySplitParameter(OpParameter *parameter) {
26 MS_CHECK_PTR_IF_NULL(parameter);
27 auto param = reinterpret_cast<SplitParameter *>(parameter);
28 if (param->split_sizes_ != nullptr) {
29 free(param->split_sizes_);
30 param->split_sizes_ = nullptr;
31 }
32 }
33
PopulateSplitParameter(const void * prim)34 OpParameter *PopulateSplitParameter(const void *prim) {
35 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
36 auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
37 auto split_prim = primitive->value_as_Split();
38 if (split_prim == nullptr) {
39 MS_LOG(ERROR) << "split_prim is nullptr";
40 return nullptr;
41 }
42 auto *split_param = reinterpret_cast<SplitParameter *>(malloc(sizeof(SplitParameter)));
43 if (split_param == nullptr) {
44 MS_LOG(ERROR) << "malloc SplitParameter failed.";
45 return nullptr;
46 }
47 memset(split_param, 0, sizeof(SplitParameter));
48 split_param->op_parameter_.type_ = schema::PrimitiveType_Split;
49 split_param->num_split_ = split_prim->numberSplit();
50 if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int)) ||
51 split_param->num_split_ <= 0) {
52 MS_LOG(ERROR) << "The value of split_param->num_split_ is out of range.";
53 free(split_param);
54 return nullptr;
55 }
56 int *split_sizes = reinterpret_cast<int *>(malloc(static_cast<size_t>(split_param->num_split_) * sizeof(int)));
57 if (split_sizes == nullptr) {
58 MS_LOG(ERROR) << "malloc split size of SplitParameter failed.";
59 free(split_param);
60 return nullptr;
61 }
62 split_param->op_parameter_.destroy_func_ = DestroySplitParameter;
63 memset(split_sizes, 0, static_cast<size_t>(split_param->num_split_) * sizeof(int));
64 split_param->split_sizes_ = split_sizes;
65 auto split_sizes_vector_ = split_prim->sizeSplits();
66 if (split_sizes_vector_ != nullptr) {
67 int i = 0;
68 for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); ++iter) {
69 split_param->split_sizes_[i++] = *iter;
70 }
71 split_param->split_count_ = split_param->num_split_;
72 } else {
73 split_param->split_count_ = 0;
74 }
75 split_param->split_dim_ = split_prim->splitDim();
76 return reinterpret_cast<OpParameter *>(split_param);
77 }
78 } // namespace
79
// Global Registry object constructed with the schema-v0 Split primitive type, the PopulateSplitParameter
// function above, and the SCHEMA_V0 version tag.
// NOTE(review): presumably the Registry constructor records the populate function for later lookup during
// static initialization — confirm in src/ops/populate/populate_register.h.
80 Registry g_splitV0ParameterRegistry(schema::v0::PrimitiveType_Split, PopulateSplitParameter, SCHEMA_V0);
81 } // namespace lite
82 } // namespace mindspore
83