/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/f16/concat_f16.h"
#include "nnacl/kernel/concat.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/utils_fp16.h"
#include "nnacl/tensor_c_utils.h"

/* Fp16 concat kernel state.
 * concat_ must remain the first member: ConcatF16Compute casts KernelBase*
 * back to ConcatF16Struct*, and CreateConcatF16 returns &concat_ as the
 * KernelBase*, so all three views must share the same address. */
typedef struct ConcatF16Struct {
  ConcatStruct concat_;
  void **tmp_buffer_; /* in_size + out_size slots; non-NULL entries own temporary
                       * fp16 copies of fp32 tensors, freed by ConcatF16FreeTmpBuffer */
} ConcatF16Struct;

/**
 * Prepare fp16 views of all inputs and the output for one Compute call.
 *
 * For every input tensor that carries data, and for the output tensor,
 * obtain an fp16 pointer via GetOrAllocFp16Data. When a tensor's native
 * type is fp32, that helper allocates a temporary fp16 buffer; such buffers
 * are recorded in tmp_buffer_ (inputs at slots [0, in_size_), output at
 * slot in_size_) so ConcatF16FreeTmpBuffer can release them afterwards.
 *
 * Returns NNACL_OK on success, or an NNACL error code on allocation failure;
 * on failure the caller must still invoke ConcatF16FreeTmpBuffer.
 */
int ConcatEnsureFp16InputsAndOutput(ConcatF16Struct *concat_f16) {
  ConcatStruct *concat = &concat_f16->concat_;

  /* size_t arithmetic: in_size_/out_size_ are unsigned counts, so computing
   * the allocation size in signed int risked narrowing/overflow. */
  size_t tmp_buffer_size = (concat->base_.in_size_ + concat->base_.out_size_) * sizeof(float16_t *);
  concat_f16->tmp_buffer_ = concat->base_.env_->Alloc(concat->base_.env_->allocator_, tmp_buffer_size);
  NNACL_CHECK_NULL_RETURN_ERR(concat_f16->tmp_buffer_);
  memset(concat_f16->tmp_buffer_, 0, tmp_buffer_size);

  for (size_t i = 0; i < concat->base_.in_size_; ++i) {
    if (!concat->is_with_data_[i]) {
      continue;
    }

    concat->inputs_ptr_[i] = GetOrAllocFp16Data(concat->base_.in_[i], concat->base_.env_, true);
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_[i]);
    if (concat->base_.in_[i]->data_type_ == kNumberTypeFloat32 ||
        concat->base_.in_[i]->data_type_ == kNumberTypeFloat) {
      /* fp32 input: the fp16 copy was freshly allocated — track it for release. */
      concat_f16->tmp_buffer_[i] = concat->inputs_ptr_[i];
    }
  }

  concat->output_ = GetOrAllocFp16Data(concat->base_.out_[OUTPUT_INDEX], concat->base_.env_, false);
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->output_);
  if (concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32 ||
      concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat) {
    /* fp32 output: the fp16 staging buffer is temporary as well. */
    concat_f16->tmp_buffer_[concat->base_.in_size_] = concat->output_;
  }
  return NNACL_OK;
}

ConcatFp16Run(void * cdata,int task_id,float l,float r)58 int ConcatFp16Run(void *cdata, int task_id, float l, float r) {
59   ConcatF16Struct *concat_f16 = (ConcatF16Struct *)cdata;
60   NNACL_CHECK_NULL_RETURN_ERR(concat_f16);
61   ConcatStruct *concat = &concat_f16->concat_;
62   return DoConcat(concat, task_id);
63 }
64 
/**
 * Release the temporary fp16 buffers recorded by
 * ConcatEnsureFp16InputsAndOutput, then the pointer table itself.
 * Safe to call when nothing was allocated (tmp_buffer_ == NULL) and
 * idempotent: tmp_buffer_ is reset to NULL so a second call is a no-op.
 */
void ConcatF16FreeTmpBuffer(ConcatF16Struct *concat_f16) {
  if (concat_f16->tmp_buffer_ == NULL) {
    return;
  }

  /* size_t index: in_size_/out_size_ are unsigned counts; the original int
   * index caused a signed/unsigned comparison in the loop condition. */
  size_t slot_num = concat_f16->concat_.base_.in_size_ + concat_f16->concat_.base_.out_size_;
  for (size_t i = 0; i < slot_num; i++) {
    if (concat_f16->tmp_buffer_[i] != NULL) {
      concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_[i]);
      concat_f16->tmp_buffer_[i] = NULL;
    }
  }

  concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_);
  concat_f16->tmp_buffer_ = NULL;
}

/**
 * Kernel Compute entry point: run the fp16 concat across the thread pool,
 * then, if the output tensor is natively fp32, convert the fp16 staging
 * result back to fp32. Temporary fp16 buffers are always released before
 * returning, on both success and error paths.
 */
int ConcatF16Compute(KernelBase *self) {
  ConcatF16Struct *concat_f16 = (ConcatF16Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat_f16);
  ConcatStruct *concat = &concat_f16->concat_;

  /* Degenerate shapes produce an empty output: nothing to do. */
  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
    return NNACL_OK;
  }

  int ret = ConcatEnsureFp16InputsAndOutput(concat_f16);
  if (ret != NNACL_OK) {
    /* Partial allocations may exist; release them before bailing out. */
    ConcatF16FreeTmpBuffer(concat_f16);
    return ret;
  }

  NNACL_CHECK_NULL_RETURN_ERR(concat->output_);
  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatFp16Run, self, self->thread_nr_);
  if (ret == NNACL_OK) {
    /* Index the output with OUTPUT_INDEX, consistent with
     * ConcatEnsureFp16InputsAndOutput (the original mixed in FIRST_INPUT). */
    TensorC *output_tensor = concat->base_.out_[OUTPUT_INDEX];
    if (output_tensor->data_type_ == kNumberTypeFloat32 || output_tensor->data_type_ == kNumberTypeFloat) {
      float *output = (float *)output_tensor->data_;
      if (output == NULL) {
        ret = NNACL_CONCAT_F16_OUTPUT_DATA_INVALID;
      } else {
        /* Copy the fp16 staging buffer back into the fp32 output tensor. */
        Float16ToFloat32((float16_t *)concat->output_, output, GetElementNum(output_tensor));
      }
    }
  }

  ConcatF16FreeTmpBuffer(concat_f16);
  return ret;
}

CreateConcatF16(OpParameter * param,int data_type)114 KernelBase *CreateConcatF16(OpParameter *param, int data_type) {
115   ConcatF16Struct *concat_f16 = (ConcatF16Struct *)malloc(sizeof(ConcatF16Struct));
116   NNACL_CHECK_NULL_RETURN_NULL(concat_f16);
117   memset(concat_f16, 0, sizeof(ConcatF16Struct));
118 
119   ConcatStruct *concat = &concat_f16->concat_;
120   concat->data_type_ = kNumberTypeFloat16;
121   concat->inner_sizes_ = NULL;
122   concat->inputs_ptr_ = NULL;
123   concat->is_with_data_ = NULL;
124   concat->base_.Prepare = ConcatPepare;
125   concat->base_.Resize = ConcatResize;
126   concat->base_.Release = ConcatRelease;
127   concat->base_.Compute = ConcatF16Compute;
128   concat_f16->tmp_buffer_ = NULL;
129   return (KernelBase *)concat;
130 }
131 
/* Register CreateConcatF16 as the creator for Concat ops with fp16 data type. */
REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat16, CreateConcatF16)