1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "nnacl/fp32/triu_tril_fp32.h"
17
/**
 * Decompose the first input's shape into a batch count plus the trailing 2-D matrix extent.
 *
 * Every dimension of the input must be strictly positive, and the input must have at least
 * two dimensions. The product of all leading (non-matrix) dimensions is the batch count.
 *
 * @param self   Kernel whose in_[FIRST_INPUT] tensor is inspected.
 * @param mul    Out: product of all dimensions except the last two (1 for a rank-2 input).
 * @param height Out: size of the second-to-last dimension.
 * @param width  Out: size of the last dimension.
 * @return NNACL_OK on success; NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR if any dimension is <= 0;
 *         NNACL_TRIU_INPUT_DIMS_INVALID if the rank is below two; the null-check error code if
 *         the input tensor is missing.
 */
int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width) {
  TensorC *input_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);

  const size_t rank = input_tensor->shape_size_;

  /* Reject empty or negative extents before looking at the rank. */
  for (size_t dim = 0; dim < rank; ++dim) {
    if (input_tensor->shape_[dim] <= 0) {
      return NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR;
    }
  }

  NNACL_CHECK_FALSE(rank < DIMENSION_2D, NNACL_TRIU_INPUT_DIMS_INVALID);

  /* Safe: rank >= 2 was just enforced, so this unsigned subtraction cannot wrap. */
  const size_t batch_rank = rank - Num2;
  int64_t batch = 1;
  for (size_t dim = 0; dim < batch_rank; ++dim) {
    batch *= input_tensor->shape_[dim];
  }

  *mul = batch;
  *height = input_tensor->shape_[rank - Num2];
  *width = input_tensor->shape_[rank - Num1];
  return NNACL_OK;
}
39
/**
 * Read the diagonal offset k for triu/tril.
 *
 * When the kernel has at most one input, k defaults to 0. Otherwise, the first element of
 * in_[SECOND_INPUT] is read as a 32-bit integer (kNumberTypeInt / kNumberTypeInt32) or as a
 * 64-bit integer (kNumberTypeInt64).
 *
 * @param self Kernel whose inputs are inspected.
 * @param k    Out: the diagonal offset.
 * @return NNACL_OK on success; NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID for any other data type;
 *         the null-check error code if the k tensor or its data pointer is NULL.
 */
int TriuTrilGetKValue(KernelBase *self, int64_t *k) {
  if (self->in_size_ <= 1) {
    /* No explicit k input: use the main diagonal. */
    *k = 0;
    return NNACL_OK;
  }

  TensorC *diag_tensor = self->in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(diag_tensor);
  NNACL_CHECK_NULL_RETURN_ERR(diag_tensor->data_);

  if (diag_tensor->data_type_ == kNumberTypeInt64) {
    *k = *((const int64_t *)diag_tensor->data_);
    return NNACL_OK;
  }
  if (diag_tensor->data_type_ == kNumberTypeInt || diag_tensor->data_type_ == kNumberTypeInt32) {
    *k = *((const int32_t *)diag_tensor->data_);
    return NNACL_OK;
  }
  return NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID;
}
63
/**
 * Upper-triangular mask over a batch of row-major height x width matrices of 8-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or above the k-th diagonal
 * (w - h >= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int64_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int64_t *src_data = (const int64_t *)src;
  int64_t *dst_data = (int64_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
78
/**
 * Upper-triangular mask over a batch of row-major height x width matrices of 4-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or above the k-th diagonal
 * (w - h >= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int32_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int32_t *src_data = (const int32_t *)src;
  int32_t *dst_data = (int32_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
93
/**
 * Upper-triangular mask over a batch of row-major height x width matrices of 2-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or above the k-th diagonal
 * (w - h >= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int16_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int16_t *src_data = (const int16_t *)src;
  int16_t *dst_data = (int16_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
/**
 * Upper-triangular mask over a batch of row-major height x width matrices of 1-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or above the k-th diagonal
 * (w - h >= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int8_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int8_t *src_data = (const int8_t *)src;
  int8_t *dst_data = (int8_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
122
/**
 * Lower-triangular mask over a batch of row-major height x width matrices of 8-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or below the k-th diagonal
 * (w - h <= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int64_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int64_t *src_data = (const int64_t *)src;
  int64_t *dst_data = (int64_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
137
/**
 * Lower-triangular mask over a batch of row-major height x width matrices of 4-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or below the k-th diagonal
 * (w - h <= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int32_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int32_t *src_data = (const int32_t *)src;
  int32_t *dst_data = (int32_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
/**
 * Lower-triangular mask over a batch of row-major height x width matrices of 2-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or below the k-th diagonal
 * (w - h <= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int16_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int16_t *src_data = (const int16_t *)src;
  int16_t *dst_data = (int16_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
/**
 * Lower-triangular mask over a batch of row-major height x width matrices of 1-byte elements.
 *
 * Element (h, w) of each matrix is copied from src when it lies on or below the k-th diagonal
 * (w - h <= k). Otherwise it is set to 0.
 *
 * @param src       Source buffer, read as int8_t elements.
 * @param dst       Destination buffer with the same layout; every element in the batch is written.
 * @param k         Diagonal offset: 0 is the main diagonal, > 0 above it, < 0 below it.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices in the batch (each spans height * width elements).
 */
void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int8_t *src_data = (const int8_t *)src;
  int8_t *dst_data = (int8_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k instead of h + k with w: k comes from an input tensor and h + k
         * can overflow int64_t (undefined behavior), whereas w - h is bounded by the matrix size. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
180