1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <limits>
#include <vector>
#include "src/ops/populate/populate_register.h"
#include "nnacl/batch_to_space.h"
using mindspore::schema::PrimitiveType_BatchToSpace;
using mindspore::schema::PrimitiveType_BatchToSpaceND;
20
21 namespace mindspore {
22 namespace lite {
PopulateBatchToSpaceParameter(const void * prim)23 OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
24 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
25 auto *primitive = static_cast<const schema::Primitive *>(prim);
26 auto value = primitive->value_as_BatchToSpace();
27 MS_CHECK_TRUE_RET(value != nullptr, nullptr);
28
29 auto *param = reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
30 if (param == nullptr) {
31 MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
32 return nullptr;
33 }
34 memset(param, 0, sizeof(BatchToSpaceParameter));
35
36 param->op_parameter_.type_ = primitive->value_type();
37 auto block_size = value->block_size();
38 if (block_size == nullptr) {
39 return reinterpret_cast<OpParameter *>(param);
40 }
41 auto block_shape = std::vector<int64_t>(block_size->begin(), block_size->end());
42 if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
43 MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
44 free(param);
45 return nullptr;
46 }
47
48 auto crop = value->crops();
49 if (crop == nullptr) {
50 MS_LOG(ERROR) << "crop is nullptr";
51 free(param);
52 return nullptr;
53 }
54 auto fb_crops = crop->data();
55 if (fb_crops == nullptr) {
56 MS_LOG(ERROR) << "fb_crops is nullptr";
57 free(param);
58 return nullptr;
59 }
60 std::vector<int64_t> crops;
61 for (auto fb_crop : *fb_crops) {
62 auto crops_data = fb_crop->data();
63 if (crops_data == nullptr) {
64 MS_LOG(ERROR) << "crops_data is nullptr";
65 free(param);
66 return nullptr;
67 }
68 auto crops_vec = std::vector<int64_t>(crops_data->begin(), crops_data->end());
69 crops.insert(crops.end(), crops_vec.begin(), crops_vec.end());
70 }
71 if (crops.size() != COMM_SHAPE_SIZE) {
72 MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE;
73 free(param);
74 return nullptr;
75 }
76
77 for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
78 param->block_shape_[i] = static_cast<int>(block_shape[i]);
79 }
80
81 for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
82 param->crops_[i] = static_cast<int>(crops[i]);
83 }
84 return reinterpret_cast<OpParameter *>(param);
85 }
86
// Route both BatchToSpace and BatchToSpaceND primitives (current schema, SCHEMA_CUR) to the same populate
// function. REG_POPULATE is defined in populate_register.h.
// NOTE(review): PopulateBatchToSpaceParameter must be able to read both value types for the second
// registration to be usable.
REG_POPULATE(PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter, SCHEMA_CUR)
89 } // namespace lite
90 } // namespace mindspore
91