/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp16/log_softmax_fp16.h"
#include <math.h>
#include <float.h>
#include "nnacl/fp16/softmax_fp16.h"
#include "nnacl/fp16/exp_fp16.h"

// LogSoftmax over the last (contiguous) axis.
//   src      : input, batch rows of `channel` contiguous fp16 values.
//   dst      : output; receives log_softmax(src) row by row.
//   exp_data : caller-provided scratch of batch * channel elements.
void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel) {
  // dst = src - rowwise max (numerical stabilization, done by the shared helper).
  SoftmaxNormFp16(src, dst, batch, channel);
  // exp_data = exp(dst) for the whole batch in one call.
  ExpFp16(dst, exp_data, batch * channel);
  int cur_batch_offset = 0;
  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
    float16_t sum = 0;
    int j = 0;
#ifdef ENABLE_NEON
    // Vectorized partial sums over 8-lane fp16 registers.
    float16x8_t sum8 = vdupq_n_f16(0);
    int count = (channel / C8NUM) * C8NUM;
    for (; j < count; j += C8NUM) {
      sum8 = vaddq_f16(sum8, vld1q_f16(exp_data + cur_batch_offset + j));
    }
    sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7];
#endif
    // Scalar tail (and the full reduction when NEON is unavailable).
    for (; j < channel; j++) {
      sum += exp_data[cur_batch_offset + j];
    }
    // log(sum) is loop-invariant: compute it once per row instead of once per
    // element, and use single-precision logf — the double-precision log()
    // added a per-iteration double promotion with no fp16 accuracy benefit.
    float16_t log_sum = logf(sum);
    for (int k = 0; k < channel; k++) {
      dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - log_sum;
    }
  }
}

// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis))
// Generic-axis fp16 log-softmax: stabilized subtract-max, exp-sum, then
// subtract the log of the sum along `axis`.
//   input_ptr   : source tensor.
//   output_ptr  : destination, same layout as input.
//   sum_data    : caller-provided scratch with outter_size * inner_size
//                 elements; overwritten here (holds exp-sums, then their logs).
//   input_shape : n_dim dimension sizes.
//   axis        : dimension to reduce over.
void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim,
                    int axis) {
  int inner_size = 1;
  int outter_size = 1;

  // Product of dims before the axis: number of independent outer slices.
  for (int i = 0; i < axis; i++) {
    outter_size *= input_shape[i];
  }
  // Product of dims after the axis: stride between consecutive axis elements.
  for (int i = axis + 1; i < n_dim; i++) {
    inner_size *= input_shape[i];
  }
  // Pass 1: per (outer, inner) element, find the max along the axis,
  // write output = input - max, and accumulate sum(exp(output)).
  for (int i = 0; i < outter_size; i++) {
    int outter_offset = i * input_shape[axis] * inner_size;
    int sum_outter_offset = i * inner_size;
    for (int k = 0; k < inner_size; k++) {
      int inner_offset = outter_offset + k;
      float16_t max_data = input_ptr[inner_offset];
      sum_data[k + sum_outter_offset] = 0;
      for (int j = 0; j < input_shape[axis]; j++) {
        int axis_offset = inner_offset + j * inner_size;
        max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset];
      }
      for (int j = 0; j < input_shape[axis]; j++) {
        int axis_offset = inner_offset + j * inner_size;
        output_ptr[axis_offset] = input_ptr[axis_offset] - max_data;
        // expf: single precision suffices for fp16 data and avoids the
        // double-precision promotion of exp().
        sum_data[k + sum_outter_offset] += expf(output_ptr[axis_offset]);
      }
    }
  }
  // Pass 2: output -= log(sum). Convert each sum to its log once (sum_data is
  // scratch, so it can be rewritten in place) instead of recomputing log()
  // input_shape[axis] times per element in the innermost loop as before.
  for (int i = 0; i < outter_size; i++) {
    int outter_offset = i * input_shape[axis] * inner_size;
    int sum_outter_offset = i * inner_size;
    for (int k = 0; k < inner_size; k++) {
      sum_data[k + sum_outter_offset] = logf(sum_data[k + sum_outter_offset]);
    }
    // Iterate j outer / k inner so the writes stay in memory order.
    for (int j = 0; j < input_shape[axis]; j++) {
      int axis_offset = outter_offset + j * inner_size;
      for (int k = 0; k < inner_size; k++) {
        int inner_offset = axis_offset + k;
        output_ptr[inner_offset] = output_ptr[inner_offset] - sum_data[k + sum_outter_offset];
      }
    }
  }
}