• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "nnacl/fp32/triu_tril_fp32.h"
17 
/**
 * Splits the first input's shape into a batch count and a trailing 2-D matrix shape.
 *
 * For an input of shape [d0, ..., d(n-3), H, W] this writes
 * *mul = d0 * ... * d(n-3) (1 when the input is exactly 2-D), *height = H, *width = W.
 *
 * @param self   Kernel whose in_[FIRST_INPUT] tensor is inspected.
 * @param mul    [out] Number of stacked matrices.
 * @param height [out] Size of the second-to-last dimension.
 * @param width  [out] Size of the last dimension.
 * @return NNACL_OK on success;
 *         NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR if any dimension is non-positive or the
 *         total element count does not fit in int64_t;
 *         NNACL_TRIU_INPUT_DIMS_INVALID if the input has fewer than two dimensions;
 *         the error code of NNACL_CHECK_NULL_RETURN_ERR for a NULL argument or input.
 *         The output parameters are written only on success.
 */
int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width) {
  NNACL_CHECK_NULL_RETURN_ERR(self);
  NNACL_CHECK_NULL_RETURN_ERR(mul);
  NNACL_CHECK_NULL_RETURN_ERR(height);
  NNACL_CHECK_NULL_RETURN_ERR(width);
  TensorC *input_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);

  /* Every dimension must be positive, and the total element count must fit in int64_t:
   * the Triu/Tril kernels index with m * height * width + h * width + w, which would
   * overflow (undefined behavior) for a larger tensor. */
  int64_t total = 1;
  for (size_t i = 0; i < input_tensor->shape_size_; i++) {
    int64_t dim = input_tensor->shape_[i];
    if (dim <= 0 || total > INT64_MAX / dim) {
      return NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR;
    }
    total *= dim;
  }

  size_t input_hw_dims = Num2;
  NNACL_CHECK_FALSE(input_tensor->shape_size_ < DIMENSION_2D, NNACL_TRIU_INPUT_DIMS_INVALID);

  /* Leading dimensions collapse into the batch count; the last two form each matrix. */
  size_t batch_dims = input_tensor->shape_size_ - input_hw_dims;
  *mul = 1;
  for (size_t i = 0; i < batch_dims; i++) {
    *mul *= input_tensor->shape_[i];
  }
  *height = input_tensor->shape_[input_tensor->shape_size_ - Num2];
  *width = input_tensor->shape_[input_tensor->shape_size_ - Num1];

  return NNACL_OK;
}
39 
/**
 * Reads the diagonal offset k used by triu/tril.
 *
 * When the kernel has at most one input, k defaults to 0. Otherwise the first element
 * of in_[SECOND_INPUT] is read as an int32 (kNumberTypeInt / kNumberTypeInt32) or an
 * int64 (kNumberTypeInt64) value and stored, widened, into *k.
 *
 * @param self Kernel whose optional second input holds k.
 * @param k    [out] Diagonal offset.
 * @return NNACL_OK on success; NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID for any other
 *         data type; the error code of NNACL_CHECK_NULL_RETURN_ERR when the k tensor
 *         or its data pointer is NULL.
 */
int TriuTrilGetKValue(KernelBase *self, int64_t *k) {
  /* No second input: use the main diagonal. */
  if (self->in_size_ < 2) {
    *k = 0;
    return NNACL_OK;
  }

  const TensorC *k_input = self->in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(k_input);
  NNACL_CHECK_NULL_RETURN_ERR(k_input->data_);

  if (k_input->data_type_ == kNumberTypeInt || k_input->data_type_ == kNumberTypeInt32) {
    *k = *(const int32_t *)k_input->data_;
    return NNACL_OK;
  }
  if (k_input->data_type_ == kNumberTypeInt64) {
    *k = *(const int64_t *)k_input->data_;
    return NNACL_OK;
  }
  return NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID;
}
63 
/**
 * Upper-triangular mask (triu) for 8-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h >= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything above it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int64_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int64_t *src_data = (const int64_t *)src;
  int64_t *dst_data = (int64_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
78 
/**
 * Upper-triangular mask (triu) for 4-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h >= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything above it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int32_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int32_t *src_data = (const int32_t *)src;
  int32_t *dst_data = (int32_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
93 
/**
 * Upper-triangular mask (triu) for 2-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h >= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything above it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int16_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int16_t *src_data = (const int16_t *)src;
  int16_t *dst_data = (int16_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
/**
 * Upper-triangular mask (triu) for 1-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h >= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything above it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int8_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int8_t *src_data = (const int8_t *)src;
  int8_t *dst_data = (int8_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h >= k) ? src_data[index] : 0;
      }
    }
  }
}
122 
/**
 * Lower-triangular mask (tril) for 8-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h <= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything below it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int64_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int64_t *src_data = (const int64_t *)src;
  int64_t *dst_data = (int64_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
137 
/**
 * Lower-triangular mask (tril) for 4-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h <= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything below it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int32_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int32_t *src_data = (const int32_t *)src;
  int32_t *dst_data = (int32_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
/**
 * Lower-triangular mask (tril) for 2-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h <= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything below it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int16_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int16_t *src_data = (const int16_t *)src;
  int16_t *dst_data = (int16_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
/**
 * Lower-triangular mask (tril) for 1-byte elements.
 *
 * src/dst are treated as out_elems consecutive row-major height x width matrices.
 * Element (h, w) of every matrix is copied when w - h <= k and set to 0 otherwise:
 * k == 0 keeps the main diagonal and everything below it, k > 0 moves the boundary
 * up, and k < 0 moves it down. dst may equal src, because each element is read and
 * written at the same index.
 *
 * @param src       Input of out_elems * height * width int8_t values.
 * @param dst       Output buffer of the same size.
 * @param k         Diagonal offset.
 * @param height    Rows per matrix.
 * @param width     Columns per matrix.
 * @param out_elems Number of matrices (batch count), not the total element count.
 */
void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) {
  const int8_t *src_data = (const int8_t *)src;
  int8_t *dst_data = (int8_t *)dst;
  for (int64_t m = 0; m < out_elems; m++) {
    int64_t m_factor = m * height * width;
    for (int64_t h = 0; h < height; h++) {
      int64_t h_factor = m_factor + h * width;
      for (int64_t w = 0; w < width; w++) {
        int64_t index = h_factor + w;
        /* Compare w - h with k, not h + k with w: w - h is bounded by the matrix
         * dimensions and cannot overflow, while h + k overflows (UB) for a
         * user-supplied k near INT64_MAX. */
        dst_data[index] = (w - h <= k) ? src_data[index] : 0;
      }
    }
  }
}
180