• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "nnacl/kernel/tril.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/common_func.h"
20 #include "nnacl/fp32/triu_tril_fp32.h"
21 
/**
 * Compute callback for the Tril kernel.
 *
 * Resolves the diagonal offset into tril->k_ via TriuTrilGetKValue and queries
 * the iteration sizes (mul, height, width) via TriuTrilGetCalculateNum. It then
 * dispatches to the TrilByteN routine that matches the input element size in
 * bytes, so one code path serves every data type of a given width.
 *
 * @param self  Kernel instance; must actually be a TrilStruct.
 * @return NNACL_OK on success; the error from TriuTrilGetKValue or
 *         TriuTrilGetCalculateNum if either fails; NNACL_UNSUPPORTED_DATA_TYPE
 *         when the element size is not 8, 4, 2 or 1 bytes. Null self or null
 *         tensor data is reported through NNACL_CHECK_NULL_RETURN_ERR.
 */
int TrilCompute(KernelBase *self) {
  TrilStruct *tril = (TrilStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(tril);

  int ret = TriuTrilGetKValue(self, &tril->k_);
  if (ret != NNACL_OK) {
    return ret;
  }

  /* NOTE(review): presumably mul is the batch (matrix) count and height/width
   * the trailing 2-D dims -- confirm against triu_tril_fp32.h. */
  int64_t mul = 0;
  int64_t height = 0;
  int64_t width = 0;
  ret = TriuTrilGetCalculateNum(self, &mul, &height, &width);
  if (ret != NNACL_OK) {
    return ret;
  }

  void *src_data = self->in_[FIRST_INPUT]->data_;
  void *dst_data = self->out_[OUTPUT_INDEX]->data_;
  /* The byte-copy kernels dereference these unconditionally; fail cleanly
   * instead of crashing when a tensor has no allocated buffer. */
  NNACL_CHECK_NULL_RETURN_ERR(src_data);
  NNACL_CHECK_NULL_RETURN_ERR(dst_data);

  int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_);
  NNACL_CHECK_ZERO_RETURN_ERR(type_size);

  /* Dispatch on element width only; element values are never interpreted. */
  switch (type_size) {
    case sizeof(int64_t): {
      TrilByte8(src_data, dst_data, tril->k_, height, width, mul);
      break;
    }
    case sizeof(int32_t): {
      TrilByte4(src_data, dst_data, tril->k_, height, width, mul);
      break;
    }
    case sizeof(int16_t): {
      TrilByte2(src_data, dst_data, tril->k_, height, width, mul);
      break;
    }
    case sizeof(int8_t): {
      TrilByte1(src_data, dst_data, tril->k_, height, width, mul);
      break;
    }
    default:
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }
  return NNACL_OK;
}
64 
CreateTril(OpParameter * param,int data_type)65 KernelBase *CreateTril(OpParameter *param, int data_type) {
66   TrilStruct *tril = (TrilStruct *)malloc(sizeof(TrilStruct));
67   NNACL_CHECK_NULL_RETURN_NULL(tril);
68   tril->base_.Release = DefaultRelease;
69   tril->base_.Prepare = DefaultPrepare1In1Out;
70   tril->base_.Resize = DefaultResize;
71   tril->base_.Compute = TrilCompute;
72   return (KernelBase *)tril;
73 }
74 
75 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeDouble, CreateTril)
76 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat, CreateTril)
77 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat64, CreateTril)
78 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat32, CreateTril)
79 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat16, CreateTril)
80 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt, CreateTril)
81 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt64, CreateTril)
82 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt32, CreateTril)
83 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt16, CreateTril)
84 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt8, CreateTril)
85 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt64, CreateTril)
86 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt32, CreateTril)
87 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt16, CreateTril)
88 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt8, CreateTril)
89 REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeBool, CreateTril)
90