• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "schema/model_v0_generated.h"
18 #include "src/ops/populate/populate_register.h"
19 #include "nnacl/split_parameter.h"
20 #include "nnacl/op_base.h"
21 
22 namespace mindspore {
23 namespace lite {
24 namespace {
DestroySplitParameter(OpParameter * parameter)25 void DestroySplitParameter(OpParameter *parameter) {
26   MS_CHECK_PTR_IF_NULL(parameter);
27   auto param = reinterpret_cast<SplitParameter *>(parameter);
28   if (param->split_sizes_ != nullptr) {
29     free(param->split_sizes_);
30     param->split_sizes_ = nullptr;
31   }
32 }
33 
PopulateSplitParameter(const void * prim)34 OpParameter *PopulateSplitParameter(const void *prim) {
35   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
36   auto *primitive = static_cast<const schema::v0::Primitive *>(prim);
37   auto split_prim = primitive->value_as_Split();
38   if (split_prim == nullptr) {
39     MS_LOG(ERROR) << "split_prim is nullptr";
40     return nullptr;
41   }
42   auto *split_param = reinterpret_cast<SplitParameter *>(malloc(sizeof(SplitParameter)));
43   if (split_param == nullptr) {
44     MS_LOG(ERROR) << "malloc SplitParameter failed.";
45     return nullptr;
46   }
47   memset(split_param, 0, sizeof(SplitParameter));
48   split_param->op_parameter_.type_ = schema::PrimitiveType_Split;
49   split_param->num_split_ = split_prim->numberSplit();
50   if (split_param->num_split_ > std::numeric_limits<int>::max() / static_cast<int>(sizeof(int)) ||
51       split_param->num_split_ <= 0) {
52     MS_LOG(ERROR) << "The value of split_param->num_split_ is out of range.";
53     free(split_param);
54     return nullptr;
55   }
56   int *split_sizes = reinterpret_cast<int *>(malloc(static_cast<size_t>(split_param->num_split_) * sizeof(int)));
57   if (split_sizes == nullptr) {
58     MS_LOG(ERROR) << "malloc split size of SplitParameter failed.";
59     free(split_param);
60     return nullptr;
61   }
62   split_param->op_parameter_.destroy_func_ = DestroySplitParameter;
63   memset(split_sizes, 0, static_cast<size_t>(split_param->num_split_) * sizeof(int));
64   split_param->split_sizes_ = split_sizes;
65   auto split_sizes_vector_ = split_prim->sizeSplits();
66   if (split_sizes_vector_ != nullptr) {
67     int i = 0;
68     for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); ++iter) {
69       split_param->split_sizes_[i++] = *iter;
70     }
71     split_param->split_count_ = split_param->num_split_;
72   } else {
73     split_param->split_count_ = 0;
74   }
75   split_param->split_dim_ = split_prim->splitDim();
76   return reinterpret_cast<OpParameter *>(split_param);
77 }
78 }  // namespace
79 
80 Registry g_splitV0ParameterRegistry(schema::v0::PrimitiveType_Split, PopulateSplitParameter, SCHEMA_V0);
81 }  // namespace lite
82 }  // namespace mindspore
83