1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "nnacl/fp32/space_to_batch_fp32.h"
17 #include "nnacl/errorcode.h"
18
/**
 * Copies the slice of a space-to-batch transform that belongs to one task.
 *
 * The task handles output batches task_id, task_id + thread_num_,
 * task_id + 2 * thread_num_, ... For every (batch, height, width) position of
 * the output it either copies one innermost run of
 * input_shape_[3] * data_type_len bytes from the matching input position, or
 * zero-fills that run when the position maps into the top/left padding or
 * beyond the input extent.
 *
 * Shapes are read as [0] = batch, [1] = height, [2] = width; index [3] only
 * contributes to the number of bytes copied per position (presumably the
 * channel count of an NHWC tensor - confirm against the caller).
 *
 * @param input   Source buffer, addressed in elements through param->in_stride_.
 * @param output  Destination buffer, addressed in elements through param->out_stride_.
 * @param param   Shapes, strides, block sizes, paddings, element size and thread count.
 * @param task_id Zero-based index of the calling task.
 * @return NNACL_OK on success; NNACL_ERR when a pointer is NULL, when
 *         thread_num_ is not positive, when task_id is negative, or when the
 *         per-position byte count is negative. A zero input batch or a zero
 *         block width is rejected by NNACL_CHECK_ZERO_RETURN_ERR.
 */
int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id) {
  if (input == NULL || output == NULL || param == NULL) {
    return NNACL_ERR;
  }
  const int thread_num = param->op_parameter_.thread_num_;
  /* A non-positive step would never terminate the batch loop (or walk into negative offsets). */
  if (thread_num <= 0 || task_id < 0) {
    return NNACL_ERR;
  }

  const int input_batch = param->input_shape_[0];
  const int input_height = param->input_shape_[1];
  const int input_width = param->input_shape_[2];

  const int output_batch = param->output_shape_[0];
  const int output_height = param->output_shape_[1];
  const int output_width = param->output_shape_[2];

  const int block_shape_height = param->block_sizes_[0];
  const int block_shape_width = param->block_sizes_[1];
  const int padding_top = param->paddings_[0];
  const int padding_left = param->paddings_[2];

  /* Both values are used as divisors below. */
  NNACL_CHECK_ZERO_RETURN_ERR(input_batch);
  NNACL_CHECK_ZERO_RETURN_ERR(block_shape_width);

  /* Bytes moved per output position; computed in 64 bits so the product cannot wrap. */
  const int64_t copy_bytes = (int64_t)param->input_shape_[3] * param->data_type_len;
  if (copy_bytes < 0) {
    return NNACL_ERR;
  }
  const size_t copy_size = (size_t)copy_bytes;

  const int8_t *src = (const int8_t *)input;
  int8_t *dst = (int8_t *)output;

  for (int64_t out_b = task_id; out_b < output_batch; out_b += thread_num) {
    /* The output batch index encodes (shift_h, shift_w, in_b), with in_b varying fastest. */
    const int in_b = (int)(out_b % input_batch);
    const int shift_w = (int)((out_b / input_batch) % block_shape_width);
    const int shift_h = (int)((out_b / input_batch) / block_shape_width);

    for (int out_h = 0; out_h < output_height; out_h++) {
      /* Row position in the unpadded input; out of range means this whole row is padding. */
      const int in_h = out_h * block_shape_height + shift_h - padding_top;
      const int row_is_padding = (in_h < 0 || in_h >= input_height);
      const int64_t out_row_offset = out_b * param->out_stride_[0] + (int64_t)out_h * param->out_stride_[1];

      for (int out_w = 0; out_w < output_width; out_w++) {
        const int64_t output_offset = out_row_offset + (int64_t)out_w * param->out_stride_[2];
        int8_t *out_ptr = dst + output_offset * param->data_type_len;
        const int in_w = out_w * block_shape_width + shift_w - padding_left;

        if (row_is_padding || in_w < 0 || in_w >= input_width) {
          memset(out_ptr, 0, copy_size);
        } else {
          /* 64-bit arithmetic, matching output_offset, so large tensors do not overflow int. */
          const int64_t input_offset = (int64_t)in_b * param->in_stride_[0] + (int64_t)in_h * param->in_stride_[1] +
                                       (int64_t)in_w * param->in_stride_[2];
          memcpy(out_ptr, src + input_offset * param->data_type_len, copy_size);
        }
      }
    }
  }
  return NNACL_OK;
}
64