• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/fp32/batch_to_space_fp32.h"
17 #include "schema/model_generated.h"
18 #include "src/kernel_registry.h"
19 
20 using mindspore::lite::KernelRegistrar;
21 using mindspore::lite::RET_ERROR;
22 using mindspore::lite::RET_OK;
23 using mindspore::schema::PrimitiveType_BatchToSpace;
24 using mindspore::schema::PrimitiveType_BatchToSpaceND;
25 
26 namespace mindspore::kernel {
Processinput()27 int BatchToSpaceCPUKernel::Processinput() {
28   CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
29   CHECK_NULL_RETURN(in_tensors_[DIMENSION_1D]);
30   CHECK_NULL_RETURN(in_tensors_[DIMENSION_2D]);
31   auto block_shape_data = in_tensors_[DIMENSION_1D]->data();
32   auto crops_data = in_tensors_[DIMENSION_2D]->data();
33   CHECK_NULL_RETURN(block_shape_data);
34   CHECK_NULL_RETURN(crops_data);
35   auto block_shape = static_cast<int *>(block_shape_data);
36   auto crops = static_cast<int *>(crops_data);
37   CHECK_LESS_RETURN(in_tensors_[DIMENSION_1D]->ElementsNum(), BATCH_TO_SPACE_BLOCK_SHAPE_SIZE);
38   CHECK_LESS_RETURN(in_tensors_[DIMENSION_2D]->ElementsNum(), COMM_SHAPE_SIZE);
39   for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
40     block_shape_[i] = block_shape[i];
41   }
42   no_crop_ = true;
43   for (int i = 0; i < COMM_SHAPE_SIZE; ++i) {
44     crops_[i] = crops[i];
45     if (crops_[i] != 0) {
46       no_crop_ = false;
47     }
48   }
49   return RET_OK;
50 }
51 
Init()52 int BatchToSpaceCPUKernel::Init() {
53   CHECK_LESS_RETURN(in_tensors_.size(), 1);
54   CHECK_LESS_RETURN(out_tensors_.size(), 1);
55   MS_ASSERT(in_tensors_[0]->format() == mindspore::NHWC);
56   if (!InferShapeDone()) {
57     return RET_OK;
58   }
59   return ReSize();
60 }
61 
ReSize()62 int BatchToSpaceCPUKernel::ReSize() {
63   MS_ASSERT(in_tensors_[0]->shape().size() == COMM_SHAPE_SIZE);
64   return RET_OK;
65 }
66 
Run()67 int BatchToSpaceCPUKernel::Run() {
68   auto input = in_tensors_[0];
69   auto output = out_tensors_[0];
70   CHECK_NULL_RETURN(input);
71   CHECK_NULL_RETURN(output);
72   const float *input_data = reinterpret_cast<const float *>(input->data());
73   float *output_data = reinterpret_cast<float *>(output->data());
74   auto in_shape = input->shape();
75   auto out_shape = output->shape();
76   if (in_tensors_.size() == 1) {
77     BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_);
78     if (param->no_crop_) {
79       BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_,
80                                 sizeof(float));
81     } else {
82       BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_,
83                           sizeof(float));
84     }
85   }
86   if (in_tensors_.size() == 3) {
87     auto ret = Processinput();
88     if (ret != RET_OK) {
89       MS_LOG(ERROR) << "Processinput failed in BatchToSpace.";
90       return ret;
91     }
92     if (no_crop_) {
93       BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(float));
94     } else {
95       BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, sizeof(float));
96     }
97   }
98   return RET_OK;
99 }
100 
101 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, LiteKernelCreator<BatchToSpaceCPUKernel>)
102 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpaceND, LiteKernelCreator<BatchToSpaceCPUKernel>)
103 }  // namespace mindspore::kernel
104