1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/fp32/batch_to_space_fp32.h"
17 #include "schema/model_generated.h"
18 #include "src/kernel_registry.h"
19
20 using mindspore::lite::KernelRegistrar;
21 using mindspore::lite::RET_ERROR;
22 using mindspore::lite::RET_OK;
23 using mindspore::schema::PrimitiveType_BatchToSpace;
24 using mindspore::schema::PrimitiveType_BatchToSpaceND;
25
26 namespace mindspore::kernel {
Processinput()27 int BatchToSpaceCPUKernel::Processinput() {
28 CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
29 CHECK_NULL_RETURN(in_tensors_[DIMENSION_1D]);
30 CHECK_NULL_RETURN(in_tensors_[DIMENSION_2D]);
31 auto block_shape_data = in_tensors_[DIMENSION_1D]->data();
32 auto crops_data = in_tensors_[DIMENSION_2D]->data();
33 CHECK_NULL_RETURN(block_shape_data);
34 CHECK_NULL_RETURN(crops_data);
35 auto block_shape = static_cast<int *>(block_shape_data);
36 auto crops = static_cast<int *>(crops_data);
37 CHECK_LESS_RETURN(in_tensors_[DIMENSION_1D]->ElementsNum(), BATCH_TO_SPACE_BLOCK_SHAPE_SIZE);
38 CHECK_LESS_RETURN(in_tensors_[DIMENSION_2D]->ElementsNum(), COMM_SHAPE_SIZE);
39 for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
40 block_shape_[i] = block_shape[i];
41 }
42 no_crop_ = true;
43 for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
44 crops_[i] = crops[i];
45 if (crops_[i] != 0) {
46 no_crop_ = false;
47 }
48 }
49 return RET_OK;
50 }
51
Init()52 int BatchToSpaceCPUKernel::Init() {
53 CHECK_LESS_RETURN(in_tensors_.size(), 1);
54 CHECK_LESS_RETURN(out_tensors_.size(), 1);
55 MS_ASSERT(in_tensors_[0]->format() == mindspore::NHWC);
56 if (!InferShapeDone()) {
57 return RET_OK;
58 }
59 return ReSize();
60 }
61
ReSize()62 int BatchToSpaceCPUKernel::ReSize() {
63 MS_ASSERT(in_tensors_[0]->shape().size() == COMM_SHAPE_SIZE);
64 return RET_OK;
65 }
66
Run()67 int BatchToSpaceCPUKernel::Run() {
68 auto input = in_tensors_[0];
69 auto output = out_tensors_[0];
70 CHECK_NULL_RETURN(input);
71 CHECK_NULL_RETURN(output);
72 const float *input_data = reinterpret_cast<const float *>(input->data());
73 float *output_data = reinterpret_cast<float *>(output->data());
74 auto in_shape = input->shape();
75 auto out_shape = output->shape();
76 if (in_tensors_.size() == 1) {
77 BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
78 if (param->no_crop_) {
79 BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
80 sizeof(float));
81 } else {
82 BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
83 sizeof(float));
84 }
85 }
86 if (in_tensors_.size() == 3) {
87 auto ret = Processinput();
88 if (ret != RET_OK) {
89 MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
90 return ret;
91 }
92 if (no_crop_) {
93 BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(float));
94 } else {
95 BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, sizeof(float));
96 }
97 }
98 return RET_OK;
99 }
100
// Register BatchToSpaceCPUKernel (through REG_KERNEL / LiteKernelCreator) as the CPU float32 kernel for
// both BatchToSpace and BatchToSpaceND; the two primitive types share this single implementation.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, LiteKernelCreator<BatchToSpaceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpaceND, LiteKernelCreator<BatchToSpaceCPUKernel>)
103 } // namespace mindspore::kernel
104