1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/f16/stack_f16.h"
18 #include "nnacl/fp16/cast_fp16.h"
19 #include "nnacl/tensor_c_utils.h"
20
/**
 * Selects the data pointer that the fp16 stack computation should use for tensor `t`.
 *
 * @param base kernel whose env_ allocator provides any temporary buffer.
 * @param t    tensor whose data is needed.
 * @param init false: t->data_ is returned as-is (no allocation, no conversion);
 *             true: a temporary buffer of GetElementNum(t) float16_t elements is allocated
 *             through base->env_->Alloc and filled via Float32ToFloat16 from t->data_.
 * @return the pointer to use, or NULL when the tensor has no data, its element count is
 *         negative, or the allocation check (NNACL_MALLOC_CHECK_NULL_RETURN_NULL) fails.
 *         A buffer obtained with init == true must later be released by the caller
 *         through base->env_->Free.
 */
void *StackF16InitBuffer(KernelBase *base, TensorC *t, bool init) {
  if (!init) {
    return t->data_;
  }

  /* Without this guard a NULL source would be handed straight to Float32ToFloat16;
   * returning NULL matches what the non-converting path reports for a missing tensor. */
  if (t->data_ == NULL) {
    return NULL;
  }

  int ele_num = GetElementNum(t);
  /* A negative count would wrap to an enormous size_t in the allocation size below. */
  if (ele_num < 0) {
    return NULL;
  }

  void *f16_buffer = base->env_->Alloc(base->env_->allocator_, (size_t)ele_num * sizeof(float16_t));
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(f16_buffer);
  Float32ToFloat16(t->data_, f16_buffer, ele_num);
  return f16_buffer;
}
32
/* Releases every temporary fp16 buffer recorded in the first `slot_num` slots, clears the
 * slots, and frees the flag array. Used to roll back a partially completed setup. */
static void StackF16RollbackBuffers(StackF16Struct *stack_f16, size_t slot_num) {
  KernelBase *base = (KernelBase *)stack_f16;
  for (size_t i = 0; i < slot_num; ++i) {
    /* Only slots flagged as converted own their buffer; the others alias tensor data. */
    if (stack_f16->init_[i] && stack_f16->stack_.buffers_[i] != NULL) {
      base->env_->Free(base->env_->allocator_, stack_f16->stack_.buffers_[i]);
    }
    stack_f16->stack_.buffers_[i] = NULL;
  }
  base->env_->Free(base->env_->allocator_, stack_f16->init_);
  stack_f16->init_ = NULL;
}

/* Fills the conversion flag and data pointer for each input and for the output tensor.
 * Returns through the NNACL check macro as soon as one pointer comes back NULL. */
static int StackF16BindBuffers(StackF16Struct *stack_f16) {
  KernelBase *base = (KernelBase *)stack_f16;

  for (size_t i = 0; i < base->in_size_; ++i) {
    stack_f16->init_[i] = base->in_[i]->data_type_ == kNumberTypeFloat32;
    stack_f16->stack_.buffers_[i] = StackF16InitBuffer(base, base->in_[i], stack_f16->init_[i]);
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[i]);
  }

  /* The output occupies the slot right after the inputs. */
  stack_f16->init_[base->in_size_] = base->out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32;
  stack_f16->stack_.buffers_[base->in_size_] =
    StackF16InitBuffer(base, base->out_[OUTPUT_INDEX], stack_f16->init_[base->in_size_]);
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[base->in_size_]);
  return NNACL_OK;
}

/**
 * Allocates the per-tensor conversion flags (init_) and binds stack_.buffers_ to fp16 data
 * for every input and for the output.
 *
 * A tensor whose data_type_ is kNumberTypeFloat32 gets its flag set and a temporary fp16
 * buffer from StackF16InitBuffer; any other tensor is used in place through its data_.
 *
 * On failure nothing stays allocated: buffers created so far and the flag array are freed,
 * init_ is reset to NULL, and all slots are cleared. Previously an error midway leaked them.
 *
 * @param stack_f16 kernel instance; stack_.buffers_ is expected to provide
 *                  in_size_ + out_size_ slots (the same range StackF16FreeBuffer clears).
 * @return NNACL_OK on success, otherwise the error produced by the NNACL check macros.
 */
int StackF16InitMallocFlags(StackF16Struct *stack_f16) {
  KernelBase *base = (KernelBase *)stack_f16;
  size_t slot_num = base->in_size_ + base->out_size_;

  stack_f16->init_ = base->env_->Alloc(base->env_->allocator_, slot_num * sizeof(bool));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->init_);

  /* Start from a known state so that every flag later read over the full
   * in_size_ + out_size_ range is defined, and rollback can tell owned buffers apart. */
  for (size_t i = 0; i < slot_num; ++i) {
    stack_f16->init_[i] = false;
    stack_f16->stack_.buffers_[i] = NULL;
  }

  int ret = StackF16BindBuffers(stack_f16);
  if (ret != NNACL_OK) {
    StackF16RollbackBuffers(stack_f16, slot_num);
  }
  return ret;
}
49
/**
 * Finishes one compute pass: publishes the result and drops all temporary fp16 storage.
 *
 * If the output slot is flagged in init_, its fp16 buffer is converted into the output
 * tensor's data_ with Float16ToFloat32. Afterwards every flagged buffer in the
 * in_size_ + out_size_ slots is freed through env_->Free, every slot is set to NULL,
 * and the flag array itself is freed and reset to NULL.
 *
 * @param stack_f16 kernel instance whose init_ and stack_.buffers_ are currently populated.
 */
void StackF16FreeBuffer(StackF16Struct *stack_f16) {
  KernelBase *kernel = &stack_f16->stack_.base_;
  const size_t out_slot = kernel->in_size_;

  if (stack_f16->init_[out_slot]) {
    /* output transfer */
    TensorC *out_tensor = kernel->out_[OUTPUT_INDEX];
    Float16ToFloat32((float16_t *)stack_f16->stack_.buffers_[out_slot], (float *)out_tensor->data_,
                     GetElementNum(out_tensor));
  }

  const size_t slot_count = kernel->in_size_ + kernel->out_size_;
  size_t slot = 0;
  while (slot < slot_count) {
    /* Unflagged slots point at tensor-owned data and must not be freed here. */
    if (stack_f16->init_[slot]) {
      kernel->env_->Free(kernel->env_->allocator_, stack_f16->stack_.buffers_[slot]);
    }
    stack_f16->stack_.buffers_[slot] = NULL;
    ++slot;
  }

  kernel->env_->Free(kernel->env_->allocator_, stack_f16->init_);
  stack_f16->init_ = NULL;
}
68
/**
 * Compute entry point of the fp16 stack kernel.
 *
 * Sets up the flags and buffers with StackF16InitMallocFlags, runs StackRun through
 * env_->ParallelLaunch using thread_nr_ threads, then calls StackF16FreeBuffer.
 *
 * @param self kernel instance; it is also used as a StackF16Struct.
 * @return the setup error if StackF16InitMallocFlags fails (nothing is launched in that
 *         case), otherwise the value returned by ParallelLaunch.
 */
int StackF16Compute(KernelBase *self) {
  StackF16Struct *kernel = (StackF16Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  int status = StackF16InitMallocFlags(kernel);
  if (status == NNACL_OK) {
    status = self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_);
    /* Runs whether or not the launch succeeded, exactly once per successful setup. */
    StackF16FreeBuffer(kernel);
  }
  return status;
}
82
CreateStackF16(OpParameter * param,int data_type)83 KernelBase *CreateStackF16(OpParameter *param, int data_type) {
84 StackF16Struct *stack_f16 = (StackF16Struct *)malloc(sizeof(StackF16Struct));
85 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack_f16);
86 StackStruct *stack = &stack_f16->stack_;
87 stack->buffers_ = NULL;
88 stack->data_type_ = data_type;
89 stack->base_.Release = StackRelease;
90 stack->base_.Prepare = StackPrepare;
91 stack->base_.Resize = StackResize;
92 stack->base_.Compute = StackF16Compute;
93 return (KernelBase *)stack;
94 }
95
96 REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat16, CreateStackF16)
97