• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "nnacl/fp32/space_to_batch_fp32.h"
17 #include "nnacl/errorcode.h"
18 
/**
 * Space-to-batch: scatters spatial blocks of `input` into the batch dimension
 * of `output`, zero-filling output positions that fall into the padding.
 *
 * Work is split across workers by output batch: this call handles
 * out_b = task_id, task_id + thread_num_, task_id + 2 * thread_num_, ...
 *
 * For an output position (out_b, out_h, out_w):
 *   in_b    = out_b % input_batch
 *   shift_h = (out_b / input_batch) / block_w
 *   shift_w = (out_b / input_batch) % block_w
 *   h_pos   = out_h * block_h + shift_h   (row in the padded input)
 *   w_pos   = out_w * block_w + shift_w   (column in the padded input)
 * If (h_pos, w_pos) lies inside the unpadded input, input_shape_[3] elements
 * are copied from (in_b, h_pos - padding_top, w_pos - padding_left).
 * Otherwise the same number of bytes is zeroed in the output.
 *
 * NOTE(review): the innermost copied dimension is input_shape_[3], which
 * suggests an NHWC layout. Confirm with callers.
 *
 * @param input   source tensor data; each element is param->data_type_len bytes.
 * @param output  destination tensor data.
 * @param param   shapes, strides, block sizes, paddings and thread count.
 * @param task_id index of this worker (presumably in [0, thread_num_)).
 * @return NNACL_OK on success; NNACL_ERR if thread_num_ is 0; otherwise the
 *         error produced by NNACL_CHECK_ZERO_RETURN_ERR when input_batch or
 *         the block width is 0.
 */
int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id) {
  if (param->op_parameter_.thread_num_ == 0) {
    return NNACL_ERR;
  }
  const int input_batch = param->input_shape_[0];
  const int input_height = param->input_shape_[1];
  const int input_width = param->input_shape_[2];

  const int output_batch = param->output_shape_[0];
  const int output_height = param->output_shape_[1];
  const int output_width = param->output_shape_[2];

  const int block_shape_height = param->block_sizes_[0];
  const int block_shape_width = param->block_sizes_[1];
  const int padding_top = param->paddings_[0];
  const int padding_left = param->paddings_[2];

  /* Both are used as divisors below. */
  NNACL_CHECK_ZERO_RETURN_ERR(input_batch);
  NNACL_CHECK_ZERO_RETURN_ERR(block_shape_width);
  const int copy_size = param->input_shape_[3] * param->data_type_len;
  int8_t *out_bytes = (int8_t *)output;
  const int8_t *in_bytes = (const int8_t *)input;

  for (int64_t out_b = task_id; out_b < output_batch; out_b += param->op_parameter_.thread_num_) {
    const int in_b = (int)(out_b % input_batch);
    const int64_t block_index = out_b / input_batch;
    const int shift_w = (int)(block_index % block_shape_width);
    const int shift_h = (int)(block_index / block_shape_width);
    for (int out_h = 0; out_h < output_height; out_h++) {
      const int h_pos = out_h * block_shape_height + shift_h;
      /* Row-level padding test is the same for every pixel in this row. */
      const int row_in_padding = h_pos < padding_top || h_pos >= padding_top + input_height;
      const int64_t out_row_offset = out_b * param->out_stride_[0] + (int64_t)out_h * param->out_stride_[1];
      for (int out_w = 0; out_w < output_width; out_w++) {
        const int w_pos = out_w * block_shape_width + shift_w;
        const int64_t output_offset = out_row_offset + (int64_t)out_w * param->out_stride_[2];
        int8_t *dst = out_bytes + output_offset * param->data_type_len;
        if (row_in_padding || w_pos < padding_left || w_pos >= padding_left + input_width) {
          memset(dst, 0, copy_size);
        } else {
          /* Flat offset in int64_t: an int can overflow on large tensors. */
          const int64_t input_offset = (int64_t)in_b * param->in_stride_[0] +
                                       (int64_t)(h_pos - padding_top) * param->in_stride_[1] +
                                       (int64_t)(w_pos - padding_left) * param->in_stride_[2];
          memcpy(dst, in_bytes + input_offset * param->data_type_len, copy_size);
        }
      }
    }
  }
  return NNACL_OK;
}
64