• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/fill.h"
18 #include "nnacl/fill_parameter.h"
19 #include "nnacl/op_base.h"
20 #include "nnacl/nnacl_common.h"
21 #include "nnacl/tensor_c_utils.h"
22 #include "nnacl/base/fill_base.h"
23 #include "nnacl/kernel/default_kernel_base.h"
24 #ifdef ENABLE_FP16
25 #include "nnacl/fp16/fill_fp16.h"
26 #endif
27 
FillResize(struct KernelBase * self)28 int FillResize(struct KernelBase *self) {
29   FillStruct *fill = (FillStruct *)self;
30   NNACL_CHECK_NULL_RETURN_ERR(fill);
31   fill->base_.thread_nr_ = fill->base_.UpdateThread(TC_PTYPE(PrimType_Fill), 0, 1,
32                                                     GetSize(fill->base_.out_[OUTPUT_INDEX]), fill->base_.thread_nr_);
33 
34   NNACL_CHECK_NULL_RETURN_ERR(fill->base_.out_[OUTPUT_INDEX]);
35   fill->data_size_ = (int)GetElementNum(fill->base_.out_[OUTPUT_INDEX]);
36   fill->thread_sz_count_ = MSMIN(fill->base_.thread_nr_, fill->data_size_);
37   if (fill->thread_sz_count_ != 0) {
38     fill->thread_sz_stride_ = UP_DIV(fill->data_size_, fill->thread_sz_count_);
39   }
40   return NNACL_OK;
41 }
42 
FillImpl(void * cdata,int task_id,float l,float r)43 int FillImpl(void *cdata, int task_id, float l, float r) {
44   FillStruct *fill = (FillStruct *)cdata;
45   NNACL_CHECK_NULL_RETURN_ERR(fill);
46   NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, fill->thread_sz_stride_, NNACL_ERR);
47   int size = MSMIN(fill->thread_sz_stride_, fill->data_size_ - task_id * fill->thread_sz_stride_);
48   NNACL_CHECK_FALSE(size <= 0, NNACL_OK);
49   int offset = task_id * fill->thread_sz_stride_;
50   int ret = NNACL_OK;
51   switch (fill->base_.in_[FIRST_INPUT]->data_type_) {
52 #ifdef ENABLE_FP16
53     case kNumberTypeFloat16:
54       ret = FillFp16((float16_t *)fill->out_ptr_ + offset, size, ((float16_t *)fill->src_data_)[FIRST_INPUT]);
55       break;
56 #endif
57     case kNumberTypeFloat32:
58       ret = FillFp32((float *)fill->out_ptr_ + offset, size, ((float *)fill->src_data_)[FIRST_INPUT]);
59       break;
60     case kNumberTypeInt32:
61       ret = FillInt32((int *)fill->out_ptr_ + offset, size, ((int *)fill->src_data_)[FIRST_INPUT]);
62       break;
63     case kNumberTypeBool:
64       ret = FillBool((bool *)fill->out_ptr_ + offset, size, ((bool *)fill->src_data_)[FIRST_INPUT]);
65       break;
66     default:
67       return NNACL_FILL_DATA_TYPE_INVALID;
68   }
69   return ret;
70 }
71 
FillCompute(struct KernelBase * self)72 int FillCompute(struct KernelBase *self) {
73   FillStruct *fill = (FillStruct *)self;
74   NNACL_CHECK_NULL_RETURN_ERR(fill);
75 
76   fill->src_data_ = (void *)fill->base_.in_[FIRST_INPUT]->data_;
77   NNACL_CHECK_NULL_RETURN_ERR(fill->src_data_);
78   fill->out_ptr_ = (void *)fill->base_.out_[OUTPUT_INDEX]->data_;
79   NNACL_CHECK_NULL_RETURN_ERR(fill->out_ptr_);
80 
81   return self->env_->ParallelLaunch(self->env_->thread_pool_, FillImpl, fill, fill->base_.thread_nr_);
82 }
83 
CreateFill(OpParameter * param,int data_type)84 KernelBase *CreateFill(OpParameter *param, int data_type) {
85   FillStruct *fill = (FillStruct *)malloc(sizeof(FillStruct));
86   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fill);
87   fill->base_.Prepare = DefaultPrepare2In1Out;
88   fill->base_.Resize = FillResize;
89   fill->base_.Release = DefaultRelease;
90   fill->base_.Compute = FillCompute;
91   return (KernelBase *)fill;
92 }
93 
/*
 * Kernel registrations: both Fill and FillV2 resolve to CreateFill for the
 * bool, int32, float32 and float16 element types.
 */
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeBool, CreateFill);
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeInt32, CreateFill);
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat32, CreateFill);
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat16, CreateFill);

REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeBool, CreateFill);
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeInt32, CreateFill);
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat32, CreateFill);
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat16, CreateFill);
103