/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "nnacl/kernel/reverse.h"
18 #include "nnacl/tensor_c_utils.h"
19 #include "nnacl/reverse_parameter.h"
20 #include "nnacl/fp32/reverse_fp32.h"
21
/*
 * Returns the product of all dimensions of `input` that come after
 * dimension `index` (1 when `index` is the last dimension), i.e. the
 * flat-index distance between neighbouring elements along `index`.
 */
int ReverseStride(TensorC *input, int index) {
  const int rank = (int)input->shape_size_;
  int product = 1;
  int dim = index + 1;
  while (dim < rank) {
    product *= input->shape_[dim];
    ++dim;
  }
  return product;
}
29
ReverseRun(void * cdata,int task_id,float l,float r)30 int ReverseRun(void *cdata, int task_id, float l, float r) {
31 ReverseStruct *reverse = (ReverseStruct *)cdata;
32 NNACL_CHECK_NULL_RETURN_ERR(reverse);
33
34 int offset = task_id * reverse->thread_stride_;
35 int count = NNACL_MIN(reverse->thread_stride_, reverse->data_size_ - offset);
36 if (count <= 0) {
37 return NNACL_OK;
38 }
39
40 float *in_ptr = (float *)reverse->base_.in_[FIRST_INPUT]->data_;
41 NNACL_CHECK_NULL_RETURN_ERR(in_ptr);
42 float *out_ptr = (float *)reverse->base_.out_[OUTPUT_INDEX]->data_;
43 NNACL_CHECK_NULL_RETURN_ERR(out_ptr);
44 return Reverse(in_ptr + offset, out_ptr, reverse->thread_stride_, reverse->tmp_ + offset);
45 }
46
/*
 * Normalizes the axes stored in the kernel's ReverseParameter in place.
 * A negative axis is shifted by the rank (shape_size_) of the first input.
 * Any axis still outside [0, rank) afterwards is rejected.
 *
 * Returns NNACL_OK on success or NNACL_REVERSE_AXIS_VALUE_INVALID.
 */
int ReverseUpdateAxisInfo(ReverseStruct *reverse) {
  ReverseParameter *param = (ReverseParameter *)reverse->base_.param_;
  const int rank = reverse->base_.in_[FIRST_INPUT]->shape_size_;

  for (int k = 0; k < param->num_axis_; ++k) {
    int axis = param->axis_[k];
    if (axis < 0) {
      axis += rank;
      param->axis_[k] = axis;
    }
    if (axis < 0 || axis >= rank) {
      return NNACL_REVERSE_AXIS_VALUE_INVALID;
    }
  }
  return NNACL_OK;
}
60
/*
 * Runs ReverseRun on self->thread_nr_ tasks through the environment's
 * ParallelLaunch on its thread_pool_, and returns ParallelLaunch's result.
 */
int ReverseCompute(KernelBase *self) {
  const int task_num = self->thread_nr_;
  return self->env_->ParallelLaunch(self->env_->thread_pool_, ReverseRun, self, task_num);
}
64
/*
 * Prepare-stage validation. Requires at least one input tensor, at least
 * one output tensor, and at least one axis (num_axis_ >= 1) in the
 * ReverseParameter.
 *
 * Returns NNACL_OK, NNACL_ERR, or NNACL_REVERSE_AXIS_INVALID.
 */
int ReversePrepare(KernelBase *self) {
  NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);

  ReverseStruct *kernel = (ReverseStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  const ReverseParameter *param = (const ReverseParameter *)self->param_;
  return (param->num_axis_ < Num1) ? NNACL_REVERSE_AXIS_INVALID : NNACL_OK;
}
75
/*
 * Frees the index table tmp_ through the environment allocator, if one is
 * held, and resets the pointer so repeated calls are harmless.
 */
int ReverseRelease(KernelBase *self) {
  ReverseStruct *kernel = (ReverseStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(kernel);

  if (kernel->tmp_ == NULL) {
    return NNACL_OK;
  }
  self->env_->Free(self->env_->allocator_, kernel->tmp_);
  kernel->tmp_ = NULL;
  return NNACL_OK;
}
85
/*
 * Validates the reverse configuration against the current input/output
 * shapes and builds the index table tmp_. After this runs, input element i
 * is copied to flat output position tmp_[i] (see ReverseRun).
 *
 * Also computes the per-thread work split (thread_nr_, thread_stride_) and
 * the per-axis strides_/in_count_/out_count_ metadata.
 *
 * Returns NNACL_OK, or an NNACL error code on invalid shapes/axes or
 * allocation failure.
 */
int ReverseResize(KernelBase *self) {
  ReverseStruct *reverse = (ReverseStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(reverse);

  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  TensorC *output = self->out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(output);

  ReverseParameter *reverse_param = (ReverseParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(reverse_param);

  /* Validate rank and axis count BEFORE ReverseUpdateAxisInfo walks
   * axis_[0..num_axis_): an oversized num_axis_ would otherwise index
   * axis_ out of bounds. (Assumes axis_ holds REVERSE_SHAPE_MAX_SIZE
   * entries — TODO confirm against reverse_parameter.h.) */
  if (input->shape_size_ > REVERSE_SHAPE_MAX_SIZE) {
    return NNACL_REVERSE_NUM_AXIS_INVALID;
  }
  if (reverse_param->num_axis_ < 0 || reverse_param->num_axis_ > (int)input->shape_size_) {
    return NNACL_REVERSE_NUM_AXIS_INVALID;
  }

  /* Turn negative axes into positive ones and reject out-of-range axes. */
  int ret = ReverseUpdateAxisInfo(reverse);
  if (ret != NNACL_OK) {
    return ret;
  }

  reverse->data_size_ = GetElementNum(input);
  if (GetElementNum(output) != reverse->data_size_) {
    return NNACL_REVERSE_DATA_SIZE_INVALID;
  }

  self->thread_nr_ = NNACL_MIN(self->thread_nr_, reverse->data_size_);
  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
  reverse->thread_stride_ = UP_DIV(reverse->data_size_, self->thread_nr_);

  /* Drop any index table left over from a previous Resize. */
  (void)self->Release(self);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(reverse->data_size_, sizeof(int), NNACL_ERR);
  reverse->tmp_ = (int *)self->env_->Alloc(self->env_->allocator_, reverse->data_size_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reverse->tmp_);
  /* No memset needed: the loop below writes every entry of tmp_. */

  /* Per-axis metadata: stride along the axis, its extent, and the product
   * of all leading dimensions. */
  for (int i = 0; i < reverse_param->num_axis_; i++) {
    int axis = reverse_param->axis_[i];
    reverse->strides_[i] = ReverseStride(input, axis);
    reverse->in_count_[i] = input->shape_[axis];
    int outer = 1;
    for (int j = 0; j < axis; j++) {
      outer *= input->shape_[j];
    }
    reverse->out_count_[i] = outer;
  }

  /* For every flat index, mirror its coordinate along each reversed axis in
   * turn. A flat index decomposes as
   *   outer_idx * (dim * stride) + axis_idx * stride + inner_idx,
   * and reversing the axis replaces axis_idx with (dim - 1 - axis_idx).
   * No division by zero: data_size_ > 0 here, so every dim and stride is
   * non-zero. */
  for (int i = 0; i < reverse->data_size_; ++i) {
    int idx = i;
    for (int j = 0; j < reverse_param->num_axis_; ++j) {
      int dim = reverse->in_count_[j];
      int stride = reverse->strides_[j];
      int outer_idx = idx / (dim * stride);
      int axis_idx = idx / stride - outer_idx * dim;
      int inner_idx = idx % stride;
      idx = outer_idx * dim * stride + stride * (dim - 1 - axis_idx) + inner_idx;
    }
    reverse->tmp_[i] = idx;
  }

  return NNACL_OK;
}
153
CreateReverse(OpParameter * param,int data_type)154 KernelBase *CreateReverse(OpParameter *param, int data_type) {
155 ReverseStruct *reverse = (ReverseStruct *)malloc(sizeof(ReverseStruct));
156 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reverse);
157 memset(reverse, 0, sizeof(ReverseStruct));
158 reverse->base_.Release = ReverseRelease;
159 reverse->base_.Prepare = ReversePrepare;
160 reverse->base_.Resize = ReverseResize;
161 reverse->base_.Compute = ReverseCompute;
162 return (KernelBase *)reverse;
163 }
164
165 REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeFloat32, CreateReverse)
166 REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeInt32, CreateReverse)
167