/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "nnacl/kernel/reshape.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/tensor_c_utils.h"
20 #include "nnacl/nnacl_common.h"
21 
/*
 * Element-count granularity used by ReshapeResize when it caps the number of
 * worker tasks: the task count is limited to
 * ceil(total elements / kMinCostPerThread).
 */
int kMinCostPerThread = 16384;
23 
ParallelReshape(void * param,int task_id,float l,float r)24 int ParallelReshape(void *param, int task_id, float l, float r) {
25   NNACL_CHECK_NULL_RETURN_ERR(param);
26   ReshapeStruct *reshape = (ReshapeStruct *)param;
27 
28   int data_size = (int)DataTypeCSize(reshape->base_.in_[0]->data_type_);
29   uint8_t *in_start = (uint8_t *)(reshape->base_.in_[0]->data_) + task_id * reshape->block_num_ * data_size;
30   uint8_t *out_start = (uint8_t *)(reshape->base_.out_[0]->data_) + task_id * reshape->block_num_ * data_size;
31   int copy_num = reshape->block_num_;
32   if (task_id == (reshape->base_.thread_nr_ - 1)) {
33     copy_num = reshape->total_num_ - task_id * reshape->block_num_;
34   }
35   (void)memcpy(out_start, in_start, copy_num * data_size);
36   return NNACL_OK;
37 }
38 
ReshapeResize(struct KernelBase * self)39 int ReshapeResize(struct KernelBase *self) {
40   NNACL_CHECK_NULL_RETURN_ERR(self);
41   ReshapeStruct *reshape = (ReshapeStruct *)self;
42   reshape->total_num_ = GetElementNum(self->in_[0]);
43   if (reshape->total_num_ == 0) {
44     return NNACL_OK;
45   }
46 
47   self->thread_nr_ = MSMIN(self->thread_nr_, UP_DIV(reshape->total_num_, kMinCostPerThread));
48   if (self->thread_nr_ < 1) {
49     self->thread_nr_ = 1;
50   }
51   NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
52   reshape->block_num_ = UP_DIV(reshape->total_num_, self->thread_nr_);
53   NNACL_CHECK_ZERO_RETURN_ERR(reshape->block_num_);
54   self->thread_nr_ = UP_DIV(reshape->total_num_, reshape->block_num_);
55 
56   return NNACL_OK;
57 }
58 
ReshapeCompute(struct KernelBase * self)59 int ReshapeCompute(struct KernelBase *self) {
60   return self->env_->ParallelLaunch(self->env_->thread_pool_, ParallelReshape, self, self->thread_nr_);
61 }
62 
CreateReshape(OpParameter * param,int data_type)63 KernelBase *CreateReshape(OpParameter *param, int data_type) {
64   ReshapeStruct *reshape = (ReshapeStruct *)malloc(sizeof(ReshapeStruct));
65   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reshape);
66   reshape->base_.Release = DefaultRelease;
67   reshape->base_.Prepare = DefaultPrepare1In1Out;
68   reshape->base_.Resize = ReshapeResize;
69   reshape->base_.Compute = ReshapeCompute;
70   return (KernelBase *)reshape;
71 }
72 
73 REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeInt32, CreateReshape)
74 REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat32, CreateReshape)
75 REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat16, CreateReshape)
76 REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeBool, CreateReshape)
77 REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeInt32, CreateReshape)
78 REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat16, CreateReshape)
79 REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat32, CreateReshape)
80 REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat16, CreateReshape)
81 REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat32, CreateReshape)
82 REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt32, CreateReshape)
83 REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat16, CreateReshape)
84 REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat32, CreateReshape)
85 REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeBool, CreateReshape)
86 REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt8, CreateReshape)
87 REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat32, CreateReshape)
88 REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat16, CreateReshape)
89 REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeInt32, CreateReshape)
90 REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeBool, CreateReshape)
91 REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat16, CreateReshape)
92 REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat32, CreateReshape)
93 REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeUInt8, CreateReshape)
94 REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt32, CreateReshape)
95 REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt64, CreateReshape)
96 REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeBool, CreateReshape)
97