/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "nnacl/kernel/fill.h"
#include <stdlib.h>
#include "nnacl/fill_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/nnacl_common.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/base/fill_base.h"
#include "nnacl/kernel/default_kernel_base.h"
#ifdef ENABLE_FP16
#include "nnacl/fp16/fill_fp16.h"
#endif
27
FillResize(struct KernelBase * self)28 int FillResize(struct KernelBase *self) {
29 FillStruct *fill = (FillStruct *)self;
30 NNACL_CHECK_NULL_RETURN_ERR(fill);
31 fill->base_.thread_nr_ = fill->base_.UpdateThread(TC_PTYPE(PrimType_Fill), 0, 1,
32 GetSize(fill->base_.out_[OUTPUT_INDEX]), fill->base_.thread_nr_);
33
34 NNACL_CHECK_NULL_RETURN_ERR(fill->base_.out_[OUTPUT_INDEX]);
35 fill->data_size_ = (int)GetElementNum(fill->base_.out_[OUTPUT_INDEX]);
36 fill->thread_sz_count_ = MSMIN(fill->base_.thread_nr_, fill->data_size_);
37 if (fill->thread_sz_count_ != 0) {
38 fill->thread_sz_stride_ = UP_DIV(fill->data_size_, fill->thread_sz_count_);
39 }
40 return NNACL_OK;
41 }
42
FillImpl(void * cdata,int task_id,float l,float r)43 int FillImpl(void *cdata, int task_id, float l, float r) {
44 FillStruct *fill = (FillStruct *)cdata;
45 NNACL_CHECK_NULL_RETURN_ERR(fill);
46 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, fill->thread_sz_stride_, NNACL_ERR);
47 int size = MSMIN(fill->thread_sz_stride_, fill->data_size_ - task_id * fill->thread_sz_stride_);
48 NNACL_CHECK_FALSE(size <= 0, NNACL_OK);
49 int offset = task_id * fill->thread_sz_stride_;
50 int ret = NNACL_OK;
51 switch (fill->base_.in_[FIRST_INPUT]->data_type_) {
52 #ifdef ENABLE_FP16
53 case kNumberTypeFloat16:
54 ret = FillFp16((float16_t *)fill->out_ptr_ + offset, size, ((float16_t *)fill->src_data_)[FIRST_INPUT]);
55 break;
56 #endif
57 case kNumberTypeFloat32:
58 ret = FillFp32((float *)fill->out_ptr_ + offset, size, ((float *)fill->src_data_)[FIRST_INPUT]);
59 break;
60 case kNumberTypeInt32:
61 ret = FillInt32((int *)fill->out_ptr_ + offset, size, ((int *)fill->src_data_)[FIRST_INPUT]);
62 break;
63 case kNumberTypeBool:
64 ret = FillBool((bool *)fill->out_ptr_ + offset, size, ((bool *)fill->src_data_)[FIRST_INPUT]);
65 break;
66 default:
67 return NNACL_FILL_DATA_TYPE_INVALID;
68 }
69 return ret;
70 }
71
FillCompute(struct KernelBase * self)72 int FillCompute(struct KernelBase *self) {
73 FillStruct *fill = (FillStruct *)self;
74 NNACL_CHECK_NULL_RETURN_ERR(fill);
75
76 fill->src_data_ = (void *)fill->base_.in_[FIRST_INPUT]->data_;
77 NNACL_CHECK_NULL_RETURN_ERR(fill->src_data_);
78 fill->out_ptr_ = (void *)fill->base_.out_[OUTPUT_INDEX]->data_;
79 NNACL_CHECK_NULL_RETURN_ERR(fill->out_ptr_);
80
81 return self->env_->ParallelLaunch(self->env_->thread_pool_, FillImpl, fill, fill->base_.thread_nr_);
82 }
83
CreateFill(OpParameter * param,int data_type)84 KernelBase *CreateFill(OpParameter *param, int data_type) {
85 FillStruct *fill = (FillStruct *)malloc(sizeof(FillStruct));
86 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fill);
87 fill->base_.Prepare = DefaultPrepare2In1Out;
88 fill->base_.Resize = FillResize;
89 fill->base_.Release = DefaultRelease;
90 fill->base_.Compute = FillCompute;
91 return (KernelBase *)fill;
92 }
93
// Register CreateFill as the kernel creator for PrimType_Fill, once per data type.
// NOTE(review): kNumberTypeFloat16 is registered unconditionally, while FillImpl only
// handles that type when ENABLE_FP16 is defined - confirm this is intended.
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeBool, CreateFill);
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeInt32, CreateFill);
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat32, CreateFill);
REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat16, CreateFill);

// PrimType_FillV2 is registered with the same creator and the same set of data types.
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeBool, CreateFill);
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeInt32, CreateFill);
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat32, CreateFill);
REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat16, CreateFill);
103