• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/int8/softmax_int8.h"
18 #include <limits>
19 #include "nnacl/int8/softmax_int8.h"
20 #include "schema/model_generated.h"
21 #include "include/errorcode.h"
22 #include "src/kernel_registry.h"
23 
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26 
27 using mindspore::lite::KernelRegistrar;
28 using mindspore::lite::RET_NULL_PTR;
29 using mindspore::schema::PrimitiveType_Softmax;
30 
31 namespace mindspore::kernel {
~SoftmaxInt8CPUKernel()32 SoftmaxInt8CPUKernel::~SoftmaxInt8CPUKernel() {
33   if (quant_param_ != nullptr) {
34     free(quant_param_);
35     quant_param_ = nullptr;
36   }
37 }
38 
Init()39 int SoftmaxInt8CPUKernel::Init() {
40   auto ret = SoftmaxBaseCPUKernel::Init();
41   if (ret != RET_OK) {
42     return ret;
43   }
44   quant_param_ = reinterpret_cast<SoftmaxQuantArg *>(malloc(sizeof(SoftmaxQuantArg)));
45   if (quant_param_ == nullptr) {
46     MS_LOG(ERROR) << "Malloc SoftmaxQuantArg for Softmax int8 op failed!";
47     return RET_ERROR;
48   }
49 
50   auto *input_tensor = in_tensors_.at(kInputIndex);
51   MS_ASSERT(input_tensor != nullptr);
52 
53   auto in_quant_args = input_tensor->quant_params();
54   quant_param_->in_quant_args_.scale_ = in_quant_args.front().scale;
55   quant_param_->in_quant_args_.zp_ = -in_quant_args.front().zeroPoint;
56 
57   auto *out_tensor = out_tensors_.at(kOutputIndex);
58   MS_ASSERT(out_tensor != nullptr);
59 
60   auto out_quant_args = out_tensor->quant_params();
61   quant_param_->out_quant_arg_.scale_ = out_quant_args.front().scale;
62   quant_param_->out_quant_arg_.zp_ = -out_quant_args.front().zeroPoint;
63   quant_param_->output_activation_min_ = std::numeric_limits<int8_t>::min();
64   quant_param_->output_activation_max_ = std::numeric_limits<int8_t>::max();
65 
66   const double input_real_multiplier =
67     MSMIN(quant_param_->in_quant_args_.scale_ * (1 << (unsigned int)(31 - 5)), (1LL << 31) - 1.0);
68   int right_shift = 0;
69   QuantizeMultiplierSmallerThanOne(input_real_multiplier, &quant_param_->output_multiplier_, &right_shift);
70   quant_param_->shift_left_ = right_shift < 0 ? -right_shift : 0;
71   quant_param_->shift_right_ = right_shift > 0 ? right_shift : 0;
72 
73   if (!InferShapeDone()) {
74     return RET_OK;
75   }
76 
77   return ReSize();
78 }
79 
ReSize()80 int SoftmaxInt8CPUKernel::ReSize() { return SoftmaxBaseCPUKernel::ReSize(); }
81 
DoSoftmax(int task_id)82 int SoftmaxInt8CPUKernel::DoSoftmax(int task_id) {
83   MS_ASSERT(in_tensors_.size() == 1);
84   MS_ASSERT(out_tensors_.size() == 1);
85 
86   auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->MutableData());
87   MS_ASSERT(input_ptr);
88   auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
89   MS_ASSERT(output_ptr);
90 
91   int outter_size = 1, inner_size = 1;
92   for (int i = 0; i < softmax_param_->axis_; i++) {
93     outter_size *= softmax_param_->input_shape_[i];
94   }
95   for (int i = softmax_param_->axis_; i < softmax_param_->n_dim_; i++) {
96     inner_size *= softmax_param_->input_shape_[i];
97   }
98 
99   int stride = UP_DIV(outter_size, thread_count_);
100   if (INT_MUL_OVERFLOW(task_id, stride)) {
101     MS_LOG(ERROR) << "int mul overflow.";
102     return RET_ERROR;
103   }
104   int count = MSMIN(stride, outter_size - stride * task_id);
105   int stride_size = stride * task_id * inner_size;
106 
107   auto error_code = SoftmaxInt8(input_ptr + stride_size, output_ptr + stride_size, count, exp_data_ + stride_size,
108                                 sum_data_, quant_param_, softmax_param_);
109   if (error_code != RET_OK) {
110     MS_LOG(ERROR) << "DoSoftmax error task_id[" << task_id << "] error_code[" << error_code << "]";
111     return RET_ERROR;
112   }
113   return RET_OK;
114 }
115 
SoftmaxRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)116 int SoftmaxRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
117   CHECK_NULL_RETURN(cdata);
118   auto softmax_kernel = reinterpret_cast<SoftmaxInt8CPUKernel *>(cdata);
119   auto error_code = softmax_kernel->DoSoftmax(task_id);
120   if (error_code != RET_OK) {
121     MS_LOG(ERROR) << "SoftmaxRun error task_id[" << task_id << "] error_code[" << error_code << "]";
122     return RET_ERROR;
123   }
124   return RET_OK;
125 }
126 
Run()127 int SoftmaxInt8CPUKernel::Run() {
128   CHECK_LESS_RETURN(MAX_MALLOC_SIZE, softmax_param_->element_size_ * sizeof(int));
129   exp_data_ = reinterpret_cast<int *>(ms_context_->allocator->Malloc(softmax_param_->element_size_ * sizeof(int)));
130   int inner_size = 1;
131   for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) {
132     if (INT_MUL_OVERFLOW(inner_size, softmax_param_->input_shape_[i])) {
133       MS_LOG(ERROR) << "int mul overflow.";
134       return RET_ERROR;
135     }
136     inner_size *= softmax_param_->input_shape_[i];
137   }
138   sum_data_ = reinterpret_cast<int *>(ms_context_->allocator->Malloc(inner_size * sizeof(int)));
139   if (exp_data_ == nullptr || sum_data_ == nullptr) {
140     MS_LOG(ERROR) << "Memory allocation failed";
141     ms_context_->allocator->Free(exp_data_);
142     ms_context_->allocator->Free(sum_data_);
143     return RET_ERROR;
144   }
145   auto ret = ParallelLaunch(this->ms_context_, SoftmaxRun, this, thread_count_);
146   ms_context_->allocator->Free(exp_data_);
147   ms_context_->allocator->Free(sum_data_);
148   if (ret != RET_OK) {
149     MS_LOG(ERROR) << "Softmax function error error_code[" << ret << "]";
150   }
151   return ret;
152 }
153 
// Associates LiteKernelCreator<SoftmaxInt8CPUKernel> with the key (kCPU, kNumberTypeInt8,
// PrimitiveType_Softmax). NOTE(review): the registration mechanics are defined by the
// REG_KERNEL macro, which is not visible in this file — confirm against src/kernel_registry.h.
154 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Softmax, LiteKernelCreator<SoftmaxInt8CPUKernel>)
155 }  // namespace mindspore::kernel
156