• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "src/ops/populate/populate_register.h"
17 #include "nnacl/batch_to_space.h"
18 using mindspore::schema::PrimitiveType_BatchToSpace;
19 using mindspore::schema::PrimitiveType_BatchToSpaceND;
20 
21 namespace mindspore {
22 namespace lite {
PopulateBatchToSpaceParameter(const void * prim)23 OpParameter *PopulateBatchToSpaceParameter(const void *prim) {
24   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
25   auto *primitive = static_cast<const schema::Primitive *>(prim);
26   auto value = primitive->value_as_BatchToSpace();
27   MS_CHECK_TRUE_RET(value != nullptr, nullptr);
28 
29   auto *param = reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
30   if (param == nullptr) {
31     MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
32     return nullptr;
33   }
34   memset(param, 0, sizeof(BatchToSpaceParameter));
35 
36   param->op_parameter_.type_ = primitive->value_type();
37   auto block_size = value->block_size();
38   if (block_size == nullptr) {
39     return reinterpret_cast<OpParameter *>(param);
40   }
41   auto block_shape = std::vector<int64_t>(block_size->begin(), block_size->end());
42   if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
43     MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
44     free(param);
45     return nullptr;
46   }
47 
48   auto crop = value->crops();
49   if (crop == nullptr) {
50     MS_LOG(ERROR) << "crop is nullptr";
51     free(param);
52     return nullptr;
53   }
54   auto fb_crops = crop->data();
55   if (fb_crops == nullptr) {
56     MS_LOG(ERROR) << "fb_crops is nullptr";
57     free(param);
58     return nullptr;
59   }
60   std::vector<int64_t> crops;
61   for (auto fb_crop : *fb_crops) {
62     auto crops_data = fb_crop->data();
63     if (crops_data == nullptr) {
64       MS_LOG(ERROR) << "crops_data is nullptr";
65       free(param);
66       return nullptr;
67     }
68     auto crops_vec = std::vector<int64_t>(crops_data->begin(), crops_data->end());
69     crops.insert(crops.end(), crops_vec.begin(), crops_vec.end());
70   }
71   if (crops.size() != COMM_SHAPE_SIZE) {
72     MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE;
73     free(param);
74     return nullptr;
75   }
76 
77   for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
78     param->block_shape_[i] = static_cast<int>(block_shape[i]);
79   }
80 
81   for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
82     param->crops_[i] = static_cast<int>(crops[i]);
83   }
84   return reinterpret_cast<OpParameter *>(param);
85 }
86 
87 REG_POPULATE(PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter, SCHEMA_CUR)
88 REG_POPULATE(PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter, SCHEMA_CUR)
89 }  // namespace lite
90 }  // namespace mindspore
91