• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "nnacl/kernel/crop.h"
18 #include "nnacl/base/crop_base.h"
19 #include "nnacl/fp32/crop_fp32.h"
20 #include "nnacl/kernel/default_kernel_base.h"
21 #ifdef ENABLE_FP16
22 #include "nnacl/fp16/crop_fp16.h"
23 #endif
24 
CropLaunch(void * cdata,int task_id,float l,float r)25 int CropLaunch(void *cdata, int task_id, float l, float r) {
26   CropStruct *crop = (CropStruct *)cdata;
27   NNACL_CHECK_NULL_RETURN_ERR(crop);
28 
29   TensorC *in = crop->base_.in_[FIRST_INPUT];
30   NNACL_CHECK_NULL_RETURN_ERR(in);
31   TensorC *out = crop->base_.out_[OUTPUT_INDEX];
32   NNACL_CHECK_NULL_RETURN_ERR(out);
33 
34 #ifdef ENABLE_FP16
35   if (in->data_type_ == kNumberTypeFloat16) {
36     Fp16Crop((float16_t *)in->data_, (float16_t *)out->data_, in->shape_, out->shape_, crop->in_offset_,
37              in->shape_size_, task_id, crop->base_.thread_nr_);
38     return NNACL_OK;
39   }
40 #endif
41 
42   CropParameter *crop_param = (CropParameter *)crop->base_.param_;
43   NNACL_CHECK_NULL_RETURN_ERR(crop_param);
44   Crop4D((float *)in->data_, (float *)out->data_, in->shape_, out->shape_, crop_param, task_id, crop->base_.thread_nr_);
45   return NNACL_OK;
46 }
47 
CropResize(struct KernelBase * self)48 int CropResize(struct KernelBase *self) {
49   TensorC *in_tensor = self->in_[FIRST_INPUT];
50   NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
51   TensorC *out_tensor = self->out_[OUTPUT_INDEX];
52   NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
53   NNACL_CHECK_FALSE(out_tensor->shape_size_ <= Num1, NNACL_OUTPUT_TENSOR_ERROR);
54 
55   CropStruct *crop = (CropStruct *)self;
56   NNACL_CHECK_NULL_RETURN_ERR(crop);
57   CropParameter *crop_param = (CropParameter *)self->param_;
58   NNACL_CHECK_NULL_RETURN_ERR(crop_param);
59 
60   return CropPadOffset(in_tensor->shape_size_, crop_param, crop->in_offset_);
61 }
62 
CropCompute(struct KernelBase * self)63 int CropCompute(struct KernelBase *self) {
64   TensorC *in_tensor = self->in_[FIRST_INPUT];
65   NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
66   TensorC *out_tensor = self->out_[OUTPUT_INDEX];
67   NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
68   CropParameter *crop_param = (CropParameter *)self->param_;
69   NNACL_CHECK_NULL_RETURN_ERR(crop_param);
70 
71   if (in_tensor->data_type_ != kNumberTypeFloat16 && out_tensor->shape_[Index1] < self->thread_nr_) {
72     float *input_data = (float *)in_tensor->data_;
73     NNACL_CHECK_NULL_RETURN_ERR(input_data);
74     float *output_data = (float *)out_tensor->data_;
75     NNACL_CHECK_NULL_RETURN_ERR(output_data);
76     Crop4DNoParallel(input_data, output_data, in_tensor->shape_, out_tensor->shape_, crop_param);
77     return NNACL_OK;
78   }
79 
80   return self->env_->ParallelLaunch(self->env_->thread_pool_, CropLaunch, self, self->thread_nr_);
81 }
82 
CreateCrop(OpParameter * param,int data_type)83 KernelBase *CreateCrop(OpParameter *param, int data_type) {
84   CropStruct *crop = (CropStruct *)malloc(sizeof(CropStruct));
85   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop);
86   memset(crop, 0, sizeof(CropStruct));
87   crop->base_.Prepare = DefaultPrepare1In1Out;
88   crop->base_.Resize = CropResize;
89   crop->base_.Release = DefaultRelease;
90   crop->base_.Compute = CropCompute;
91   return (KernelBase *)crop;
92 }
93 
94 REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeInt32, CreateCrop)
95 REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat32, CreateCrop)
96 REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat16, CreateCrop)
97