• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/reverse.h"
18 #include "nnacl/tensor_c_utils.h"
19 #include "nnacl/reverse_parameter.h"
20 #include "nnacl/fp32/reverse_fp32.h"
21 
/*
 * Returns the product of all dimensions of `input` that come after `index`,
 * i.e. how many flat elements one step along dimension `index` spans.
 * Returns 1 when `index` is the last dimension.
 */
int ReverseStride(TensorC *input, int index) {
  const int rank = (int)input->shape_size_;
  int product = 1;
  int dim = index + 1;
  while (dim < rank) {
    product *= input->shape_[dim];
    ++dim;
  }
  return product;
}
29 
ReverseRun(void * cdata,int task_id,float l,float r)30 int ReverseRun(void *cdata, int task_id, float l, float r) {
31   ReverseStruct *reverse = (ReverseStruct *)cdata;
32   NNACL_CHECK_NULL_RETURN_ERR(reverse);
33 
34   int offset = task_id * reverse->thread_stride_;
35   int count = NNACL_MIN(reverse->thread_stride_, reverse->data_size_ - offset);
36   if (count <= 0) {
37     return NNACL_OK;
38   }
39 
40   float *in_ptr = (float *)reverse->base_.in_[FIRST_INPUT]->data_;
41   NNACL_CHECK_NULL_RETURN_ERR(in_ptr);
42   float *out_ptr = (float *)reverse->base_.out_[OUTPUT_INDEX]->data_;
43   NNACL_CHECK_NULL_RETURN_ERR(out_ptr);
44   return Reverse(in_ptr + offset, out_ptr, reverse->thread_stride_, reverse->tmp_ + offset);
45 }
46 
/*
 * Rewrites the axes stored in the kernel's ReverseParameter in place.
 * A negative axis is wrapped by adding the input rank. After wrapping, every
 * axis must lie in [0, rank).
 * Returns NNACL_REVERSE_AXIS_VALUE_INVALID at the first axis out of range
 * (axes before it remain rewritten), otherwise NNACL_OK.
 */
int ReverseUpdateAxisInfo(ReverseStruct *reverse) {
  ReverseParameter *param = (ReverseParameter *)reverse->base_.param_;
  const int rank = reverse->base_.in_[FIRST_INPUT]->shape_size_;
  for (int k = 0; k < param->num_axis_; k++) {
    int axis = param->axis_[k];
    if (axis < 0) {
      axis += rank;
      param->axis_[k] = axis;
    }
    if (!(axis >= 0 && axis < rank)) {
      return NNACL_REVERSE_AXIS_VALUE_INVALID;
    }
  }
  return NNACL_OK;
}
60 
/*
 * Executes the reversal by handing ReverseRun to the environment's
 * ParallelLaunch with self->thread_nr_ tasks. Returns ParallelLaunch's result
 * unchanged.
 */
int ReverseCompute(KernelBase *self) {
  const int task_count = self->thread_nr_;
  return self->env_->ParallelLaunch(self->env_->thread_pool_, ReverseRun, self, task_count);
}
64 
/*
 * Validates the static configuration of the reverse kernel:
 * - at least one input tensor and one output tensor,
 * - at least one axis to reverse (num_axis_ >= Num1).
 *
 * The kernel and its parameter are null-checked before they are dereferenced.
 * The original checked the kernel only after reading self->in_size_, which made
 * that check ineffective.
 *
 * @return NNACL_OK, NNACL_ERR for a null kernel/parameter or too few tensors,
 *         or NNACL_REVERSE_AXIS_INVALID when no axis is configured.
 */
int ReversePrepare(KernelBase *self) {
  NNACL_CHECK_NULL_RETURN_ERR(self);
  NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);

  ReverseParameter *reverse_param = (ReverseParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(reverse_param);
  if (reverse_param->num_axis_ < Num1) {
    return NNACL_REVERSE_AXIS_INVALID;
  }
  return NNACL_OK;
}
75 
/*
 * Frees the index table (tmp_) through the environment allocator, if one is
 * held, and clears the pointer so repeated calls are harmless.
 * Returns NNACL_OK, or an error if the kernel pointer is NULL.
 */
int ReverseRelease(KernelBase *self) {
  ReverseStruct *reverse = (ReverseStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(reverse);
  if (reverse->tmp_ == NULL) {
    return NNACL_OK;
  }
  self->env_->Free(self->env_->allocator_, reverse->tmp_);
  reverse->tmp_ = NULL;
  return NNACL_OK;
}
85 
/*
 * Shape-dependent setup for the reverse kernel.
 *
 * 1. Validates the axis count and input rank against the input shape and
 *    REVERSE_SHAPE_MAX_SIZE. This happens BEFORE any of the per-axis arrays
 *    (axis_, strides_, in_count_, out_count_) are indexed with num_axis_.
 *    The original normalized axis_[0..num_axis_) first and only bounded
 *    num_axis_ afterwards.
 * 2. Normalizes negative axes (ReverseUpdateAxisInfo).
 * 3. Checks that input and output have the same element count, and splits that
 *    count into thread_nr_ slices of thread_stride_ elements.
 * 4. (Re)allocates tmp_, a table of data_size_ ints. tmp_[i] is the flat index
 *    that input element i maps to after every configured axis is reversed.
 *
 * NOTE(review): strides_/in_count_/out_count_ are presumably sized
 * REVERSE_SHAPE_MAX_SIZE (see reverse.h). The rank check below relies on it.
 *
 * @return NNACL_OK or an NNACL error code.
 */
int ReverseResize(KernelBase *self) {
  ReverseStruct *reverse = (ReverseStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(reverse);

  TensorC *input = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input);
  TensorC *output = self->out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(output);
  ReverseParameter *reverse_param = (ReverseParameter *)self->param_;
  NNACL_CHECK_NULL_RETURN_ERR(reverse_param);

  /* Bound the axis count and rank before anything is indexed with them. */
  if (reverse_param->num_axis_ > input->shape_size_) {
    return NNACL_REVERSE_NUM_AXIS_INVALID;
  }
  if (input->shape_size_ > REVERSE_SHAPE_MAX_SIZE) {
    return NNACL_REVERSE_NUM_AXIS_INVALID;
  }

  /* Turn negative axes into non-negative ones and range-check them. */
  int ret = ReverseUpdateAxisInfo(reverse);
  if (ret != NNACL_OK) {
    return ret;
  }

  reverse->data_size_ = GetElementNum(input);
  if (GetElementNum(output) != reverse->data_size_) {
    return NNACL_REVERSE_DATA_SIZE_INVALID;
  }

  self->thread_nr_ = NNACL_MIN(self->thread_nr_, reverse->data_size_);
  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
  reverse->thread_stride_ = UP_DIV(reverse->data_size_, self->thread_nr_);

  /* Drop any table from a previous resize before allocating a new one. */
  (void)self->Release(self);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(reverse->data_size_, sizeof(int), NNACL_ERR);
  reverse->tmp_ = (int *)self->env_->Alloc(self->env_->allocator_, reverse->data_size_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reverse->tmp_);
  memset(reverse->tmp_, 0, reverse->data_size_ * sizeof(int));

  /* Per reversed axis: elements per step along it, its extent, and the
   * product of the dimensions before it. */
  for (int i = 0; i < reverse_param->num_axis_; i++) {
    int axis = reverse_param->axis_[i];
    reverse->strides_[i] = ReverseStride(input, axis);
    reverse->in_count_[i] = input->shape_[axis];
    reverse->out_count_[i] = 1;
    for (int j = 0; j < axis; j++) {
      reverse->out_count_[i] *= input->shape_[j];
    }
  }

  /* Build the destination-index table. For each reversed axis, the running
   * flat index is split into (outer block, position along the axis, offset
   * inside one step). The position is mirrored to extent - 1 - pos and the
   * parts are recombined. */
  for (int i = 0; i < reverse->data_size_; ++i) {
    int idx = i;
    for (int j = 0; j < reverse_param->num_axis_; ++j) {
      int stride = reverse->strides_[j];
      int extent = reverse->in_count_[j];
      int outer = idx / (extent * stride);
      int pos = idx / stride - outer * extent;
      int inner = idx % stride;
      idx = outer * extent * stride + stride * (extent - 1 - pos) + inner;
    }
    reverse->tmp_[i] = idx;
  }

  return NNACL_OK;
}
153 
CreateReverse(OpParameter * param,int data_type)154 KernelBase *CreateReverse(OpParameter *param, int data_type) {
155   ReverseStruct *reverse = (ReverseStruct *)malloc(sizeof(ReverseStruct));
156   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reverse);
157   memset(reverse, 0, sizeof(ReverseStruct));
158   reverse->base_.Release = ReverseRelease;
159   reverse->base_.Prepare = ReversePrepare;
160   reverse->base_.Resize = ReverseResize;
161   reverse->base_.Compute = ReverseCompute;
162   return (KernelBase *)reverse;
163 }
164 
165 REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeFloat32, CreateReverse)
166 REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeInt32, CreateReverse)
167